In [None]:
import os
import math
import numpy
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F

import gc

from torch.optim.lr_scheduler import StepLR

import torchvision.transforms as T
import torchvision.transforms.functional as TF

from torchvision.models.vision_transformer import VisionTransformer

from pytorch_pretrained_vit import ViT

In [None]:
# Seed PyTorch's global random number generator with a fixed value so that
# randomly initialised weights and shuffled batch order are repeatable.
torch.manual_seed(20)

In [None]:
# Choose the compute device once: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Input directories, one per data split. Each is handed to ToothDataset,
# which treats every entry in the directory as one sample.
dir_training = '/kaggle/input/icdas-70x70/icdas_preprocessed/training'
dir_testing = '/kaggle/input/icdas-70x70/icdas_preprocessed/testing'

In [None]:
class ToothDataset(Dataset):
    """Dataset of 3-D volumes stored as one ``.npy`` file per sample.

    Every entry of ``img_dir`` is treated as a sample. The class label is
    taken from the file name: 1 when its fourth character is ``'B'``,
    otherwise 0. Each volume is cropped to its last 70 slices along axis 0
    and returned as a float32 tensor of shape ``(1, 70, 70, 70)``.

    Args:
        img_dir: Directory containing the ``.npy`` volumes.
        transform: Optional callable applied to the tensor before it is
            returned.

    Attributes:
        dataset_path: The directory passed as ``img_dir``.
        transform: The optional transform.
        filenames: Sorted directory listing, captured once at construction.
    """

    def __init__(self, img_dir, transform=None):
        self.dataset_path = img_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order, so the listing is
        # taken once and sorted: index -> file is then stable for the whole
        # lifetime of the dataset, and the directory is not rescanned on
        # every __len__/__getitem__ call.
        self.filenames = sorted(os.listdir(self.dataset_path))

    def __len__(self):
        """Return the number of samples (entries found in the directory)."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A ``(volume, label)`` pair: a float32 tensor of shape
            ``(1, 70, 70, 70)`` and a ``LongTensor`` of shape ``(1,)``.

        Raises:
            IndexError: If ``idx`` is outside the dataset. Raising (instead
                of returning ``None``) is what the sequence protocol expects
                and lets plain iteration over the dataset terminate.
        """
        count = len(self.filenames)
        if not -count <= idx < count:
            raise IndexError("No datafile/image at index : " + str(idx))

        npy_filename = self.filenames[idx]
        label = int(npy_filename[3] == 'B')

        volume = numpy.load(os.path.join(self.dataset_path, npy_filename))

        # Keep only the last 70 slices along axis 0. A single slice replaces
        # dropping the leading slices one at a time, which copied the whole
        # array once per dropped slice.
        volume = volume[-70:]

        volume = volume.reshape(1, 70, 70, 70)
        tensor_arr = torch.from_numpy(volume).to(torch.float32)

        if self.transform:
            tensor_arr = self.transform(tensor_arr)  # Apply transformations

        return tensor_arr.to(torch.float32), torch.LongTensor([label])

In [None]:
# Build one dataset per split, without any extra transform.
# Note that the "testing" directory is what serves as the validation set.
training_data = ToothDataset(img_dir=dir_training, transform=None)
validation_data = ToothDataset(img_dir=dir_testing, transform=None)

In [None]:

class Basic3DCNN(nn.Module):
    """Small 3-D CNN classifier for single-channel cubic volumes.

    Three ``Conv3d -> ReLU -> MaxPool3d(2)`` stages (16, 32 and 64 channels)
    are followed by a 256-unit hidden layer with dropout and a linear output
    layer producing one logit per class.

    Args:
        num_classes: Number of output logits.
        input_size: Edge length of the cubic input volume. The default of 70
            gives 70 -> 35 -> 17 -> 8 after the three poolings, i.e.
            ``64 * 8 * 8 * 8`` flattened features.

    Raises:
        ValueError: If ``input_size`` is too small to survive three 2x
            poolings.
    """

    def __init__(self, num_classes=4, input_size=70):
        super(Basic3DCNN, self).__init__()
        # Convolutional layers; kernel 3 / stride 1 / padding 1 keeps the
        # spatial size unchanged.
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        # Max pooling layer, shared by all three stages; halves each spatial
        # dimension (rounding down).
        self.pool = nn.MaxPool3d(2, 2)

        # Spatial edge length left after the three poolings.
        pooled_edge = input_size
        for _ in range(3):
            pooled_edge //= 2
        if pooled_edge < 1:
            raise ValueError(
                f"input_size={input_size} is too small for three 2x poolings"
            )
        self.flat_features = 64 * pooled_edge ** 3

        # Fully connected layers
        self.fc1 = nn.Linear(self.flat_features, 256)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a ``(N, 1, D, H, W)`` batch to ``(N, num_classes)`` logits."""
        # Convolutional layers with ReLU activation and max pooling
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten everything except the batch dimension. Keeping the batch
        # dimension fixed means an input of the wrong spatial size fails in
        # fc1 with a shape error instead of being silently re-split into a
        # different number of samples.
        x = torch.flatten(x, 1)
        # Fully connected layers with ReLU activation
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


In [None]:
# Instantiate the 3-D CNN with two output logits (labels are 0 or 1) and move
# its parameters to the selected device.
model = Basic3DCNN(num_classes=2).to(device)

In [None]:
# Take the volume of the first training sample and add a leading batch
# dimension, giving a (1, 1, 70, 70, 70) tensor on the model's device.
datax = training_data[0][0].reshape(1,1,70,70,70).to(device)

In [None]:
# Smoke test: one forward pass on a single sample to confirm the layer shapes
# line up. The cell output is the raw (1, 2) logit tensor.
model(datax)

In [None]:
# Hyperparameters
epochs = 500           # number of passes over the training set
batch_size = 2         # samples per mini-batch for both data loaders
learning_rate = 1e-3   # initial step size handed to the optimizer
# NOTE(review): in the code visible in this notebook, weight_decay and
# momentum are defined but never passed to the optimizer - confirm whether
# they are meant to be used.
weight_decay = 1e-10
momentum = 0.9

In [None]:
# Cross-entropy over the raw logits returned by the model.
loss_function=nn.CrossEntropyLoss()
# Adam with the configured learning rate only; the weight_decay and momentum
# hyperparameters are not passed to it.
optimizer = torch.optim.Adam( model.parameters()  ,lr=learning_rate)
# optimizer=torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# StepLR multiplies the learning rate by gamma every step_size calls to
# scheduler.step().
# NOTE(review): scheduler.step() is not called anywhere in the visible
# training code, so the learning rate currently stays constant - confirm
# whether that is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
# Mini-batch loaders: the training split is reshuffled every epoch, the
# validation split is always read in its fixed order.
training_data_loader = DataLoader(
    dataset=training_data,
    batch_size=batch_size,
    shuffle=True,
)
validation_data_loader = DataLoader(
    dataset=validation_data,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Batches are moved to the module-level ``device``. The loss of every
    fifth batch is printed together with a running sample count.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``dataloader.dataset``
            must support ``len()`` (used only for the progress print-out).
        model: Network to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.
    """
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error.
        pred = model(X)

        # Flatten the labels to shape (N,). A bare squeeze() would also
        # remove the batch dimension when the batch holds a single sample
        # (e.g. the last batch of an odd-sized dataset), leaving a 0-d target
        # that no longer matches the (1, C) prediction.
        loss = loss_fn(pred, y.reshape(-1))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 5 == 0:
            # Print
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

In [None]:
# Accuracy (in percent) of every validation run, in call order.
validation_accuracy = []
def validation(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Batches are moved to the module-level ``device``. The resulting accuracy
    is appended to the module-level ``validation_accuracy`` list and a
    summary line is printed.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``dataloader.dataset``
            and the loader itself must support ``len()``.
        model: Network to evaluate; it is switched to evaluation mode.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar loss.

    Returns:
        Accuracy over the whole dataset, in percent.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            # Flatten the labels to shape (N,). A bare squeeze() would also
            # remove the batch dimension of a single-sample batch, leaving a
            # 0-d target that no longer matches the (1, C) prediction.
            target = y.reshape(-1)
            test_loss += loss_fn(pred, target).item()
            correct += (torch.argmax(pred, dim=1) == target).sum().item()
    test_loss /= num_batches  # mean of the per-batch losses
    correct /= size           # fraction of correctly classified samples
    validation_accuracy.append(correct*100)
    # Print
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct*100

In [None]:

# Directory where model checkpoints are written. exist_ok=True makes the cell
# safe to re-run and avoids a separate existence check.
save_dir = "/kaggle/working/saved_models"
os.makedirs(save_dir, exist_ok=True)

best_accuracy = 0.0  # Best validation accuracy (percent) seen so far
best_epoch = 0  # 1-based epoch at which best_accuracy was reached

# Training
# NOTE(review): the StepLR `scheduler` created alongside the optimizer is
# never stepped in this loop, so the learning rate stays constant. Confirm
# the intent before adding `scheduler.step()`: with step_size=10 and
# gamma=0.1 the rate would shrink tenfold every 10 epochs.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(training_data_loader, model, loss_function, optimizer)
    accuracy = validation(validation_data_loader, model, loss_function)

    # Periodic checkpoint: save the weights on every 100th epoch.
    if (t + 1) % 100 == 0:
        save_path = os.path.join(save_dir, f"model_epoch_{t+1}_accuracy_{accuracy:.2f}.pt")
        torch.save(model.state_dict(), save_path)
        print(f"Model saved at epoch {t+1}")

    # Save a separate checkpoint whenever validation accuracy strictly
    # improves on the best value seen so far.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_epoch = t + 1
        best_model_path = os.path.join(save_dir, f"best_model_epoch_{best_epoch}_accuracy_{best_accuracy:.2f}.pt")
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved with accuracy: {best_accuracy:.2f} at epoch {best_epoch}")

print("Training done!")
print(f"Best validation accuracy: {best_accuracy:.2f} at epoch {best_epoch}")