In [6]:
from pathlib import Path

import torch
import torchvision.transforms as transforms
import timm

import torch.nn as nn

import torch.optim as optim

from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
from tqdm import tqdm

from tooth_crop_dataset import ToothCropClassDataset
from utils.vit import train, test

# TensorBoard run directory; the trained weights are also written here at the end.
log_dir = Path('runs', 'mobile_net_v2_th09')
writer = SummaryWriter(log_dir=log_dir)


In [7]:
# Configuration
# Seed PyTorch's global RNG (all devices) so the train/test split, DataLoader
# shuffling and the freshly initialised classifier head are reproducible
# under Restart & Run All.
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Preprocess: ToTensor yields a CxHxW float tensor (scaled to [0, 1] for uint8 /
# PIL input — TODO confirm what the dataset returns), resized to 224x224;
# Normalize computes (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224)),
    transforms.Normalize(mean=0.5, std=0.5),
])
# Targets: presumably a multi-hot label vector (the dataset exposes an `mlb`
# binarizer). BCEWithLogitsLoss needs float targets. torch.as_tensor is used
# instead of the legacy torch.Tensor(y) constructor, which for an int argument
# allocates an uninitialised tensor of that length instead of wrapping the value.
target_transform = transforms.Compose([
    (lambda y: torch.as_tensor(y, dtype=torch.float32)),
])

# Hyperparameters
epoch_num = 240
batch_size = 16
# Worker processes started with 'spawn' must pickle the transforms, which fails
# for the lambda above; keep 0 unless it is replaced by an importable function.
num_workers = 0
train_test_split = 0.8  # fraction of samples assigned to the training split

In [8]:
dataset = ToothCropClassDataset(root='../preprocess', transform=transform, target_transform=target_transform)

dataset_size = len(dataset)
train_size = int(train_test_split * dataset_size)
test_size = dataset_size - train_size
assert train_size > 0 and test_size > 0, (
    f"train/test split {train_size}/{test_size} leaves an empty subset"
)

# A dedicated, seeded generator keeps the split identical across re-runs,
# regardless of how many random numbers were drawn from the global RNG before.
split_generator = torch.Generator().manual_seed(42)
train_set, test_set = random_split(dataset, [train_size, test_size], generator=split_generator)

train_loader = DataLoader(train_set, batch_size=batch_size,
                          shuffle=True, num_workers=num_workers)
# Evaluation gains nothing from shuffling; a fixed order keeps test passes comparable
# and does not consume global RNG draws.
test_loader = DataLoader(test_set, batch_size=batch_size,
                         shuffle=False, num_workers=num_workers)

classes = dataset.mlb.classes_


def count_labels(loader, num_classes):
    """Return the per-class sum of the targets over every batch in ``loader``.

    With 0/1 multi-hot targets this is the number of positive samples per class.
    Note: this iterates the whole loader, so every image is loaded and
    transformed once.
    """
    counts = torch.zeros(num_classes)
    for _, targets in loader:
        counts += targets.sum(dim=0)
    return counts


train_label_count = count_labels(train_loader, len(classes))
test_label_count = count_labels(test_loader, len(classes))

print(classes)
print(train_label_count)
print(test_label_count)



Total data in 1041
['R.R' 'caries' 'crown' 'endo' 'filling' 'post']
tensor([ 23.,  53., 207., 216., 464., 134.])
tensor([  5.,  16.,  55.,  61., 107.,  44.])


In [9]:
# Pretrained MobileNetV2 whose classifier head is sized to the dataset's label set.
model = timm.create_model('mobilenetv2_100', num_classes=len(classes), pretrained=True)
model.to(device)

# BCEWithLogitsLoss: one sigmoid/BCE term per class (multi-label). A pos_weight of
# all ones is equivalent to no re-weighting. It is kept explicit (float32, created
# on the training device) as the knob for class imbalance, e.g. negatives /
# positives per class derived from `train_label_count`.
pos_weight = torch.ones(len(classes), device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
SGD_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


In [10]:
# Threshold passed to `test`; presumably the per-class decision threshold on the
# sigmoid scores (TODO confirm in utils.vit). Matches the 'th09' run name.
threshold = 0.9

try:
    for epoch in tqdm(range(epoch_num), desc='epochs'):
        train(train_loader, model, criterion, SGD_optimizer, writer=writer, epoch=epoch, device=device)
        test(test_loader, model, criterion, len(classes), device=device, writer=writer,
             epoch=epoch, classes=classes, threshold=threshold)
finally:
    # Flush and close the TensorBoard event files even if training is interrupted.
    writer.close()

print('Finished Training')
# Save the final weights next to the TensorBoard logs.
torch.save(model.state_dict(), log_dir / 'trained-net.pt')

100%|██████████| 240/240 [12:22<00:00,  3.09s/it]

Done!
Finished Training



