In [1]:
import torch
import torchvision.transforms as transforms
import timm

import torch.nn as nn

import torch.optim as optim

from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split

from tooth_crop_dataset import ToothCropClassDataset
from utils.vit import train, test

# TensorBoard logger handed to the train/test helpers below; log_dir=None keeps
# SummaryWriter's default run directory.
writer = SummaryWriter(log_dir=None)


In [2]:
# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Preprocess: image -> tensor, resized to the 224x224 input of the
# 'vit_base_patch16_224' model used below, then normalized with
# mean=0.5 / std=0.5 applied to every channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224)),
    transforms.Normalize(mean=0.5, std=0.5),
])


def to_float_target(y):
    """Convert a label vector to a float32 tensor (BCEWithLogitsLoss needs float targets).

    ``torch.as_tensor`` is used instead of the legacy ``torch.Tensor(y)``
    constructor, which interprets an integer argument as a *size* and would
    silently return an uninitialised tensor for a scalar label.

    This is a named function rather than a lambda because lambdas cannot be
    pickled, which DataLoader requires for num_workers > 0 under the 'spawn'
    start method (the Windows default). In a notebook, workers must also be
    able to import it, so move it to a .py module before raising num_workers.
    """
    return torch.as_tensor(y, dtype=torch.float32)


target_transform = transforms.Compose([
    to_float_target,
])

# Hyperparameters
epoch_num, batch_size = 120, 16   # training epochs; samples per DataLoader batch
num_workers = 0                   # DataLoader worker processes (0 = load in the main process)
train_test_split = 0.8            # fraction of the dataset assigned to the training split

In [3]:
dataset = ToothCropClassDataset(root='../preprocess', transform=transform, target_transform=target_transform)

dataset_size = len(dataset)
train_size = int(train_test_split * dataset_size)
test_size = dataset_size - train_size
# Fail loudly on a degenerate split (e.g. an empty or tiny data directory)
# instead of training or evaluating on an empty partition.
assert 0 < train_size < dataset_size, f"split {train_size}/{test_size} leaves a partition empty"

# Seed the split so the train/test partition is identical on every run.
# Without this, re-running this cell (or evaluating a saved checkpoint in a
# later session) draws a new partition and leaks training images into the test set.
split_seed = 42
train_set, test_set = random_split(
    dataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, num_workers=num_workers)
# Evaluation does not need shuffling; a fixed order keeps test passes deterministic.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                          shuffle=False, num_workers=num_workers)

classes = dataset.mlb.classes_


def count_positive_labels(loader, num_classes):
    """Sum the label vectors of every batch in ``loader``.

    Args:
        loader: iterable of ``(inputs, targets)`` batches; ``targets`` is
            expected to be shaped ``(batch, num_classes)`` (multi-hot labels).
        num_classes: length of each label vector.

    Returns:
        1-D float tensor of length ``num_classes``: the per-class sum of
        targets, i.e. the positive count when targets are 0/1.
    """
    counts = torch.zeros(num_classes)
    for _, targets in loader:
        counts += targets.sum(dim=0)
    return counts


# Note: iterating the loaders runs the image transform on every sample just to
# read the labels; acceptable for ~1k images.
train_label_count = count_positive_labels(train_loader, len(classes))
test_label_count = count_positive_labels(test_loader, len(classes))

print(classes)
print(train_label_count)
print(test_label_count)



Total data in 1041
['R.R' 'caries' 'crown' 'endo' 'filling' 'post']
tensor([ 25.,  54., 208., 221., 458., 140.])
tensor([  3.,  15.,  54.,  56., 113.,  38.])


In [12]:
# Pretrained timm ViT-B/16 with one output logit per class; the head size is
# derived from the dataset's class list instead of being hard-coded.
model = timm.create_model('vit_base_patch16_224', num_classes=len(classes), pretrained=True)
model.to(device)

# Per-class positive weights for BCEWithLogitsLoss, keyed by class name so they
# cannot drift out of alignment with the class order; unlisted classes get 1.0.
# For classes ['R.R', 'caries', 'crown', 'endo', 'filling', 'post'] this equals
# the former positional [1, 2, 1, 1, 1, 1] (caries up-weighted 2x), as a float tensor.
pos_weight_by_class = {'caries': 2.0}
assert set(pos_weight_by_class) <= set(classes), "pos_weight refers to an unknown class"
pos_weight = torch.tensor([pos_weight_by_class.get(c, 1.0) for c in classes],
                          dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
SGD_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


In [None]:
# Train for epoch_num epochs, evaluating on the held-out split after each one.
# try/finally guarantees the TensorBoard writer is flushed and closed even if
# training raises or is interrupted from the notebook.
try:
    for t in range(epoch_num):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_loader, model, criterion, SGD_optimizer, writer=writer, epoch=t, device=device)
        test(test_loader, model, criterion, len(classes), device=device, writer=writer, epoch=t, classes=classes)
finally:
    writer.close()
print("Done!")

print('Finished Training')
# Save the fine-tuned weights (state_dict only; restore with model.load_state_dict).
torch.save(model.state_dict(), './pretrained-net.pt')


Epoch 1
-------------------------------


Exception in thread Thread-6:
Traceback (most recent call last):
  File "D:\Users\douli\anaconda3\envs\dentist-CV-main\lib\threading.py", line 932, in _bootstrap_inner
    self.run()
  File "D:\Users\douli\anaconda3\envs\dentist-CV-main\lib\site-packages\tensorboard\summary\writer\event_file_writer.py", line 233, in run
    self._record_writer.write(data)
  File "D:\Users\douli\anaconda3\envs\dentist-CV-main\lib\site-packages\tensorboard\summary\writer\record_writer.py", line 40, in write
    self._writer.write(header + header_crc + data + footer_crc)
  File "D:\Users\douli\anaconda3\envs\dentist-CV-main\lib\site-packages\tensorboard\compat\tensorflow_stub\io\gfile.py", line 766, in write
    self.fs.append(self.filename, file_content, self.binary_mode)
  File "D:\Users\douli\anaconda3\envs\dentist-CV-main\lib\site-packages\tensorboard\compat\tensorflow_stub\io\gfile.py", line 160, in append
    self._write(filename, file_content, "ab" if binary_mode else "a")
  File "D:\Users\douli\an

loss: 0.710990  [    0/  832]
Test Error:
Accuracy: 91.1%, Avg loss: 0.667483
Precision: 81.7%, Sensitivity: 77.1% 

Class R.R:
Accuracy: 98.6%
Precision: nan%, Sensitivity: 0.0%
Class caries:
Accuracy: 85.2%
Precision: 21.4%, Sensitivity: 40.0%
Class crown:
Accuracy: 90.0%
Precision: 86.7%, Sensitivity: 72.2%
Class endo:
Accuracy: 93.8%
Precision: 85.2%, Sensitivity: 92.9%
Class filling:
Accuracy: 91.4%
Precision: 90.6%, Sensitivity: 93.8%
Class post:
Accuracy: 87.6%
Precision: 100.0%, Sensitivity: 31.6%
Epoch 2
-------------------------------
loss: 0.216443  [    0/  832]
Test Error:
Accuracy: 92.5%, Avg loss: 0.664215
Precision: 80.7%, Sensitivity: 87.1% 

Class R.R:
Accuracy: 98.6%
Precision: nan%, Sensitivity: 0.0%
Class caries:
Accuracy: 91.4%
Precision: 28.6%, Sensitivity: 13.3%
Class crown:
Accuracy: 91.4%
Precision: 76.5%, Sensitivity: 96.3%
Class endo:
Accuracy: 90.9%
Precision: 75.3%, Sensitivity: 98.2%
Class filling:
Accuracy: 90.9%
Precision: 92.0%, Sensitivity: 91.2%
Clas