In [2]:
!pip install pytorch_pretrained_vit



In [3]:
import os
import gc
import math
import numpy
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
%matplotlib inline
from torchvision.models.vision_transformer import VisionTransformer
from pytorch_pretrained_vit import ViT

In [4]:
torch.manual_seed(20)

<torch._C.Generator at 0x7f29a85b61f0>

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [6]:
# Directory Names
dir_training = './drive/MyDrive/inlay_onlay_dataset/training'
dir_validation = './drive/MyDrive/inlay_onlay_dataset/validation'
dir_testing = './drive/MyDrive/inlay_onlay_dataset/testing'

In [7]:
class ToothDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.dataset_path = img_dir
        self.transform = transform

    def __len__(self):
        return len(os.listdir(self.dataset_path))

    def __getitem__(self, idx):
        if idx  >= len(os.listdir(self.dataset_path)):
            print("No datafile/image at index : "+ str(idx))
            return None
        npy_filename = os.listdir(self.dataset_path)[idx]
        label = int('onlay' in npy_filename)
        numpy_arr = numpy.load(self.dataset_path + '/' + npy_filename)
        for i in range(numpy_arr.shape[0]-70): numpy_arr = numpy.delete(numpy_arr, [0], axis=0)
        numpy_arr = numpy_arr.reshape(1, 70, 70, 70)
        tensor_arr = torch.from_numpy(numpy_arr).to(torch.float32)

        del numpy_arr
        gc.collect()

        if self.transform: tensor_arr = self.transform(tensor_arr) # Apply transformations

        return tensor_arr.to(torch.float32), torch.LongTensor([label])

In [10]:
training_data = ToothDataset(img_dir=dir_training, transform=None)
validation_data = ToothDataset(img_dir=dir_validation, transform=None)

In [11]:
# Download the pretrained weights
pretrained_vit = ViT('B_16_imagenet1k', pretrained=True)

Loaded pretrained weights.


In [12]:
# Freeze the ViT
for param in pretrained_vit.parameters():
    param.requires_grad = False

In [13]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        # The CNN layers with maxpool and ReLU
        self.features = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.Conv3d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.Flatten(),
        )

        # Pretrained ViT
        self.vit = pretrained_vit

        # Fully connected layer
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1000,512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512,2)
        )

    def forward(self, x):
        x = self.features(x)

        # Reshape the CNN output (output of size batch*4096) into batch*3x384x384
        x = x.repeat(1, 36, 1, 1, 1)
        x = x.view(-1, 1, 384, 384)
        x = x.repeat(1, 3, 1, 1)

        # Pass the reshaped CNN output into ViT
        x = self.vit(x)

        # Pass it into the fully connected layers
        x = self.fc(x)
        return x

In [14]:
model = NeuralNetwork().to(device)

In [15]:
# Checking a forward pass, if no error is detected, then eveything is okay
single_training_sample = training_data[0][0].reshape(1,1,70,70,70).to(device)
model(single_training_sample)

tensor([[0.0287, 0.2092]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [16]:
# Hyperparameters
epochs = 15
batch_size = 2
learning_rate = 1e-6
weight_decay = 0.1

In [17]:
loss_function=nn.CrossEntropyLoss()
optimizer = torch.optim.Adam( model.parameters()  ,lr=learning_rate)

In [18]:
training_data_loader = DataLoader(training_data, batch_size, shuffle = True)
validation_data_loader = DataLoader(validation_data, batch_size, shuffle = False)

In [19]:
# The training function
def train(dataloader, model, loss_fn, optimizer):
    torch.cuda.empty_cache()
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y.squeeze())

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch%5==0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


        del pred
        del loss

In [20]:
# Keep record of the validation accuracy
validation_accuracy = []

# The validation function
def validation(dataloader, model, loss_fn):
    global validation_accuracy
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y.squeeze()).item()
            correct += (torch.argmax(pred, dim=1) == y.squeeze()).sum().item()
            X.cpu()
            y.cpu()
    test_loss /= num_batches
    correct /= size

    # Keep record of the validation accuracy globally
    validation_accuracy.append(correct*100)
    # Print
    print(f"Validation Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    print(f"Current best validation accuracy: {max(validation_accuracy)}%")

In [21]:
# Training
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(training_data_loader, model, loss_function, optimizer)
    validation(validation_data_loader, model, loss_function)
print("Done!")

Epoch 1
-------------------------------
loss: 0.828423  [    0/   20]
loss: 0.637491  [   10/   20]
Validation Error: 
 Accuracy: 50.0%, Avg loss: 0.675662 

Current best validation accuracy: 50.0%
Epoch 2
-------------------------------
loss: 0.570743  [    0/   20]
loss: 0.796301  [   10/   20]
Validation Error: 
 Accuracy: 50.0%, Avg loss: 0.672706 

Current best validation accuracy: 50.0%
Epoch 3
-------------------------------
loss: 0.626618  [    0/   20]
loss: 0.644107  [   10/   20]
Validation Error: 
 Accuracy: 50.0%, Avg loss: 0.670998 

Current best validation accuracy: 50.0%
Epoch 4
-------------------------------
loss: 0.819908  [    0/   20]
loss: 0.573188  [   10/   20]
Validation Error: 
 Accuracy: 50.0%, Avg loss: 0.669826 

Current best validation accuracy: 50.0%
Epoch 5
-------------------------------
loss: 0.662102  [    0/   20]
loss: 0.886220  [   10/   20]
Validation Error: 
 Accuracy: 50.0%, Avg loss: 0.669390 

Current best validation accuracy: 50.0%
Epoch 6
--