In [4]:
%pip install pytorch-metric-learning
%pip install torch-summary

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [5]:
import os
import json
import zipfile
import subprocess
import shutil
import getpass
import math
import numpy
import pandas
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image,ImageReadMode
import matplotlib.pyplot as plt
import plotly.express as px
from pytorch_metric_learning import losses, regularizers
from torchsummary import summary
from torch.nn.utils.prune import l1_unstructured, random_unstructured
import torch.nn.utils.prune as prune
import copy

In [6]:
torch.manual_seed(40)

<torch._C.Generator at 0x7faeda17f310>

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
print(f"Using {device} device")

Using cuda device


In [8]:
datasetA_save_dir = './datasetA'
datasetB_save_dir = './datasetB'

# Model Loading and training configurations settings
load_best_first = True
keep_initialization_records = True
start_training_on_best_weights = True # Start Training with best weights

# File Paths
initial_best_model_path = './best_model@68.99%.pth'
save_current_best_model_path = './current_best_model.pth'
save_current_initializations = './current_initializations.pth'
save_validation_accuracy_path = './val_acc.npy'

In [13]:
def one_hot_encode(val):
    arr = numpy.zeros((6,), dtype=int)
    arr[val] = 1
    return arr

def get_bucket_id(age):
  age_floor = int(age)
  if age_floor >= 0 and age_floor <= 5: return 0
  elif age_floor >= 6 and age_floor <= 12: return 1
  elif age_floor >= 13 and age_floor <= 19: return 2
  elif age_floor >= 20 and age_floor <= 29: return 3
  elif age_floor >= 30 and age_floor <= 59: return 4
  else: return 5

def get_ground_truth(age):
  return one_hot_encode(get_bucket_id(age))

In [14]:
def train(dataloader, model, optimizer, loss_function, loss_optimizer):
    torch.cuda.empty_cache()
    size = len(dataloader.dataset)
    model.train()
    loss_tot = 0.0
    num = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Zeroing the gradients
        optimizer.zero_grad()
        loss_optimizer.zero_grad()

        # Get Embeddings
        embeddings = model(X)

        # Calculate Loss
        loss = loss_function(embeddings, y.argmax(1))

        # Backpropagation
        loss.backward()

        # Update
        optimizer.step()
        loss_optimizer.step()

        loss_tot += loss.item()
        num += 1

        del X
        del y
        del embeddings
        del loss
        torch.cuda.empty_cache()

    # loss_tot /= num
    print(f'training loss: {(loss_tot):>0.5f}')

In [15]:
def train_with_pseudo_label(dataloader, model, optimizer, loss_function, loss_optimizer, best_model):
    torch.cuda.empty_cache()
    size = len(dataloader.dataset)
    model.train()
    loss_tot = 0.0
    num = 0
    tmp = 0
    for batch, X in enumerate(dataloader):
        X = X.to(device)
        tmp = tmp + X.shape[0]

        # Pseudo-Labelling
        embedding = best_model(X)
        logits = best_model.margin_loss_function.get_logits(embedding)
        predSoftmax = nn.Softmax(dim=1)(logits)
        pseudo_label = predSoftmax.argmax(1)

        # Zeroing the gradients
        optimizer.zero_grad()
        loss_optimizer.zero_grad()

        # Get Embeddings
        embeddings = model(X)

        # Calculate Loss
        loss = loss_function(embeddings, pseudo_label)

        # Backpropagation
        loss.backward()

        # Update
        optimizer.step()
        loss_optimizer.step()

        loss_tot += loss.item()
        num += 1

        del X
        del embedding
        del logits
        del predSoftmax
        del pseudo_label
        del embeddings
        del loss
        torch.cuda.empty_cache()
    # loss_tot /= num
    print(f'pseudo-label training loss: {(loss_tot):>0.5f}')
    print(f'total-unlabelled-training: {tmp}')

In [16]:
def save_model_checkpoint(model, path):
    torch.save({
      'model_state_dict': model.state_dict(),
      'arcface_state_dict': model.margin_loss_function.state_dict(),
    }, path)

In [17]:
def deep_copy_model(model_to_copy):
    model_copy = copy.deepcopy(model_to_copy)
    return model_copy

In [18]:
validation_accuracy = []
current_max_val_acc = 0.0
def validation(dataloader, model, loss_function, dont_log = False):
    global current_max_val_acc
    global best_model
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    correct = 0
    totalsize = 0
    loss_tot = 0.0
    num = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            # Forward With Logits
            embeddings = model(X)
            logits = model.margin_loss_function.get_logits(embeddings)
            predSoftmax = nn.Softmax(dim=1)(logits)
            loss = model.margin_loss_function(embeddings,y.argmax(1))

            correct += (predSoftmax.argmax(1) == y.argmax(1)).sum().item()
            totalsize += predSoftmax.shape[0]
            loss_tot += loss.item()
            num += 1
            X.cpu()
            y.cpu()

    print(f"Correct/Total: {correct}/{totalsize}")
    correct /= totalsize
    validation_accuracy.append(correct*100)
    print(f"Validation Loss:  {(loss_tot):>0.5f}")
    print(f"Validation Accuracy: {(100*correct):>0.5f}%\n")

    if 100*correct > current_max_val_acc and dont_log == False:
        current_max_val_acc = 100*correct
        save_model_checkpoint(model, save_current_best_model_path)
        numpy.save(save_validation_accuracy_path, numpy.array(validation_accuracy), allow_pickle=True, fix_imports=True)
        if correct > 67.0:
            best_model = copy.deepcopy(model)
            print("Change best_model")
        print("Saved")

    print(f"Current Best Validation Accuracy: {(current_max_val_acc):>0.5f}%\n")
    return loss_tot, correct

In [19]:
class XRayToothDatasetA(Dataset):
    def __init__(self, cwd, img_dir, transform=None, target_height=None, target_width=None):
        self.dataset_path = cwd + '/' + img_dir
        self.transform = transform
        self.target_height = target_height
        self.target_width = target_width

    def __len__(self):
        return len(os.listdir(self.dataset_path))

    def __getitem__(self, idx):
        if idx  >= len(os.listdir(self.dataset_path)):
            print("No datafile/image at index : "+ str(idx))
            return None
        img_filename = os.listdir(self.dataset_path)[idx]
        age = float(img_filename.split("_")[1][:-4])
        age_gt = get_ground_truth(age)
        image_tensor = read_image(path=self.dataset_path + '/' + img_filename)
        image_tensor = image_tensor.reshape(1, 3, image_tensor.shape[-2], image_tensor.shape[-1])
        if self.target_height and self.target_width: # Resize the image
            image_tensor = torch.nn.functional.interpolate(image_tensor, (self.target_height,self.target_width))
        if self.transform: image_tensor = self.transform(image_tensor) # Apply transformations
        image_tensor = (image_tensor-image_tensor.min())/(image_tensor.max()-image_tensor.min())
        return image_tensor.reshape(-1,image_tensor.shape[-2],image_tensor.shape[-1]).to(torch.float32), torch.tensor(age_gt)

In [20]:
# Gives only the unlabelled data
class XRayToothDatasetB(Dataset):
    def __init__(self, cwd, img_dir, transform=None, target_height=None, target_width=None):
        self.dataset_path = cwd + '/' + img_dir
        self.transform = transform
        self.target_height = target_height
        self.target_width = target_width

    def __len__(self):
        return len(os.listdir(self.dataset_path))

    def __getitem__(self, idx):
        if idx  >= len(os.listdir(self.dataset_path)):
            print("No datafile/image at index : "+ str(idx))
            return None
        img_filename = os.listdir(self.dataset_path)[idx]
        image_tensor = read_image(path=self.dataset_path + '/' + img_filename)
        image_tensor = image_tensor.reshape(1, 3, image_tensor.shape[-2], image_tensor.shape[-1])
        if self.target_height and self.target_width: # Resize the image
            image_tensor = torch.nn.functional.interpolate(image_tensor, (self.target_height,self.target_width))
        if self.transform: image_tensor = self.transform(image_tensor) # Apply transformations
        image_tensor = (image_tensor-image_tensor.min())/(image_tensor.max()-image_tensor.min())
        return image_tensor.reshape(-1,image_tensor.shape[-2],image_tensor.shape[-1]).to(torch.float32)

In [22]:
training_data = XRayToothDatasetA(os.getcwd(), img_dir=datasetA_save_dir+'/training', transform=None, target_height=224, target_width=224)
validation_data = XRayToothDatasetA(os.getcwd(), img_dir=datasetA_save_dir+'/validation', transform=None, target_height=224, target_width=224)
training_data_unlabelled = XRayToothDatasetB(os.getcwd(), img_dir=datasetB_save_dir+'/training', transform=None, target_height=224, target_width=224)

In [23]:
from torchvision.models import vit_l_32, ViT_L_32_Weights

pretrained_vit = vit_l_32(weights=ViT_L_32_Weights.IMAGENET1K_V1)

In [24]:
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

pretrained_effnet = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

In [25]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone1 = pretrained_vit
        self.backbone2 = pretrained_effnet

        for param in self.backbone1.parameters():
            param.requires_grad = False

        for param in self.backbone2.parameters():
            param.requires_grad = False

        self.fc = nn.Sequential(
            nn.Dropout(0.6),
            nn.Linear(2000,512),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(512,256),
        )

        self.margin_loss_regularizer = regularizers.RegularFaceRegularizer()
        self.margin_loss_function = losses.ArcFaceLoss(6, 256, margin=34.3, scale=1, weight_regularizer=self.margin_loss_regularizer)

    def forward(self, x):
        x1 = self.backbone1(x)
        x2 = self.backbone2(x)
        x = torch.cat((x1,x2), 1)
        x = self.fc(x)
        return x

In [26]:
model = NeuralNetwork().to(device)

if keep_initialization_records == True: save_model_checkpoint(model, save_current_initializations) # Remember initializaion

if load_best_first == True:
    best_model = NeuralNetwork().to(device)
    best_model.load_state_dict(torch.load(initial_best_model_path)['model_state_dict'])
    best_model.margin_loss_function.load_state_dict(torch.load(initial_best_model_path)['arcface_state_dict'])
    print("Loaded best model in best_model.")
else:
    best_model = copy.deepcopy(model)
    print("Didn't load best_model")

if start_training_on_best_weights == True:
    model = deep_copy_model(best_model) # Start Training with best weights
    print("Training will start with best weights till now!")
else:
    print("Training will start with random weights.")

margin_loss_function = model.margin_loss_function

Loaded best model in best_model.
Training will start with best weights till now!


In [27]:
# Test a forward pass
with torch.no_grad():
    embeddings = model(training_data[0][0].reshape(-1,3,224,224).to(device))
    logits = model.margin_loss_function.get_logits(embeddings)
    predSoftmax = nn.Softmax(dim=1)(logits)
    print(predSoftmax)

tensor([[0.1083, 0.1208, 0.1318, 0.2247, 0.2789, 0.1356]], device='cuda:0')


In [28]:
# Training Hyperparameters
short_epoch_limit = 1000
batch_size = 500
learning_rate = 3.8742e-04
momentum=0.9
weight_decay=0.9

# Margin Loss Hyperparameters
margin_loss_learning_rate = 3.8742e-05

In [29]:
trainingA_data_loader = DataLoader(training_data, batch_size, shuffle = True)
validationA_data_loader = DataLoader(validation_data, batch_size, shuffle = False)
trainingB_data_loader = DataLoader(training_data_unlabelled, batch_size, shuffle = True)

In [30]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_optimizer = torch.optim.Adam(model.margin_loss_function.parameters(), lr=margin_loss_learning_rate)

# Schedulers
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=5, min_lr=1e-7,verbose=True)
margin_loss_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(loss_optimizer, 'min', factor=0.9, patience=5, min_lr=1e-7,verbose=True)

In [33]:
current_max_val_acc = 0.0

In [34]:
significant_long_epochs_count = 0
def run_long_epoch(model, best_model):
    global significant_long_epochs_count
    print(f"Long Epoch {significant_long_epochs_count+1}\n-------------------------------")
    # Pseudo Label and Train with that pseudo labelling
    train_with_pseudo_label(trainingB_data_loader, model, optimizer, margin_loss_function, loss_optimizer, best_model)

    # Validate on Dataset A's Validation Samples
    val_loss,val_acc = validation(validationA_data_loader, model, margin_loss_function)
    # scheduler2.step(val_loss)
    # margin_loss_scheduler2.step(val_loss)

    significant_long_epochs_count = significant_long_epochs_count + 1

    return val_acc

In [None]:
for t in range(short_epoch_limit):
    print(f"Short Epoch {t+1}\n-------------------------------")
    # Train @ alpha on Dataset A's Training Samples
    train(trainingA_data_loader, model, optimizer, margin_loss_function, loss_optimizer)

    # Validate on Dataset A's Validation Samples
    val_loss,val_acc = validation(validationA_data_loader, model, margin_loss_function)
    # scheduler2.step(val_loss)
    # margin_loss_scheduler2.step(val_loss)

    # Run Long Epoch
    run_long_epoch(model, best_model)
print("Done!")

Short Epoch 1
-------------------------------
training loss: 4.26254
Correct/Total: 42/129
Validation Loss:  2.30492
Validation Accuracy: 32.55814%

Saved
Current Best Validation Accuracy: 32.55814%

Long Epoch 1
-------------------------------
pseudo-label training loss: 4.03958
total-unlabelled-training: 1000
Correct/Total: 38/129
Validation Loss:  2.29103
Validation Accuracy: 29.45736%

Current Best Validation Accuracy: 32.55814%

Short Epoch 2
-------------------------------
training loss: 4.23630
Correct/Total: 41/129
Validation Loss:  2.13293
Validation Accuracy: 31.78295%

Current Best Validation Accuracy: 32.55814%

Long Epoch 2
-------------------------------
pseudo-label training loss: 3.72483
total-unlabelled-training: 1000
Correct/Total: 45/129
Validation Loss:  2.09007
Validation Accuracy: 34.88372%

Saved
Current Best Validation Accuracy: 34.88372%

Short Epoch 3
-------------------------------
training loss: 4.03782
Correct/Total: 64/129
Validation Loss:  1.92527
Validat