In [None]:
from PIL import Image
import csv
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time, os, json

# Class labels for the `iga_grade` annotation. The position of a name in this
# list is the integer class index (the dataset maps a grade name to its label
# with `class_names.index(...)`), so the order must not be changed.
class_names = [
    'Normal',
    'Almost Clear',
    'Mild',
    'Moderate',
    'Severe',
]

In [None]:
class CustomDataset(Dataset):
    """Image dataset stored as ``<root_dir>/<case>/crop/<name>.jpg``.

    In every mode except ``'Test'`` each image is paired with the label read from
    ``<root_dir>/<case>/metadata/<name>.json``.

    Args:
        root_dir: Directory that contains the per-case sub-directories.
        mode: ``'Test'`` makes ``__getitem__`` return ``(image, image_path)``;
            any other value makes it return ``(image, grade)``.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, root_dir, mode = 'Training', transform = None):
        self.root_dir = root_dir
        # glob yields files in arbitrary, filesystem-dependent order; sorting
        # makes sample indices (and the order of test results) reproducible.
        self.img_path = sorted(glob.glob(root_dir + '/*/crop/*.jpg'))
        self.transform = transform
        self.mode = mode

    def __len__(self):
        """Return the number of ``.jpg`` files found under ``root_dir``."""
        return len(self.img_path)

    @staticmethod
    def _metadata_path(img_path):
        """Map ``<case>/crop/<name>.jpg`` to ``<case>/metadata/<name>.json``."""
        crop_dir, file_name = os.path.split(img_path)
        case_dir = os.path.dirname(crop_dir)
        stem = os.path.splitext(file_name)[0]
        return os.path.join(case_dir, 'metadata', stem + '.json')

    def __getitem__(self, idx):
        """Return ``(image, image_path)`` in ``'Test'`` mode, else ``(image, grade)``.

        ``grade`` is the index of the annotated ``iga_grade`` in ``class_names``.

        Raises:
            ValueError: if the annotated ``iga_grade`` is not in ``class_names``.
            OSError / KeyError / IndexError: if the image or metadata file is
                missing or does not have the expected structure.
        """
        path = self.img_path[idx]
        # Close the underlying file as soon as the pixels are converted.
        with Image.open(path) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        if self.mode == 'Test':
            return img, path

        # Binary mode lets json detect the (UTF-8/16/32) encoding itself instead
        # of depending on the platform locale; `with` closes the handle.
        with open(self._metadata_path(path), 'rb') as meta_file:
            data = json.load(meta_file)
        clinical_info = data['annotations'][0]['clinical_info']
        if 'iga_grade' in clinical_info:
            grade = class_names.index(clinical_info['iga_grade'])
        else:
            # Samples without an `iga_grade` entry are labelled as class 0.
            grade = 0
        return img, grade

In [None]:
# Channel-wise mean / std handed to Normalize in both pipelines. These are the
# ImageNet statistics that torchvision's pretrained classification models expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random horizontal flip as augmentation.
train_transforms = transforms.Compose([
    transforms.Resize(299),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline (validation and test): must be deterministic, so the
# random crop is replaced by a fixed centre crop of the same output size.
# Otherwise the same image produces different inputs, and therefore different
# metrics/predictions, on every pass.
test_transforms = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, train_loader, valid_loader, history = None):
    """Train ``model`` and evaluate it on ``valid_loader`` after every epoch.

    Each epoch performs one optimisation pass over ``train_loader``, one
    gradient-free pass over ``valid_loader``, one ``scheduler.step()`` and prints
    a one-line summary (elapsed time, mean batch loss and accuracy for both sets).

    Args:
        model: ``torch.nn.Module`` to train. Batches are moved to the device of
            its first parameter (CPU if it has no parameters).
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar loss.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler; stepped once per epoch.
        num_epochs: Number of epochs to run.
        train_loader: Iterable of ``(inputs, targets)`` batches supporting
            ``len()`` and exposing a sized ``.dataset``.
        valid_loader: Same contract as ``train_loader``, used for evaluation only.
        history: Optional dict. When given, the per-epoch floats are appended to
            the lists under ``'train_loss'``, ``'train_acc'``, ``'valid_loss'``
            and ``'valid_acc'`` (created on demand).

    Returns:
        The same ``model`` object, trained in place and left in eval mode.
    """
    # Follow the model's own device instead of relying on a notebook-global
    # `device` variable that may not exist (or may be stale) in a fresh kernel.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    for epoch in range(num_epochs):
        start = time.time()

        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        for train_x, train_y in train_loader:
            train_x, train_y = train_x.to(device), train_y.to(device)
            optimizer.zero_grad()
            pred = model(train_x)
            loss = criterion(pred, train_y)
            loss.backward()
            optimizer.step()
            _, preds = torch.max(pred, 1)
            train_loss += loss.item()
            # .item() keeps the counter a plain int instead of a device tensor.
            train_correct += torch.sum(preds == train_y).item()

        # ---- validation pass ----
        model.eval()
        valid_loss = 0.0
        valid_correct = 0
        with torch.no_grad():
            for valid_x, valid_y in valid_loader:
                valid_x, valid_y = valid_x.to(device), valid_y.to(device)
                pred = model(valid_x)
                valid_loss += criterion(pred, valid_y).item()
                # Accuracy must be measured on the validation batch itself (one
                # forward pass, no gradients), not on the last training batch.
                _, preds = torch.max(pred, 1)
                valid_correct += torch.sum(preds == valid_y).item()

        # Step the schedule once per epoch. Stepping it per batch would make a
        # StepLR(step_size=7) decay the learning rate every 7 batches.
        scheduler.step()

        # max(..., 1) avoids a ZeroDivisionError when a loader is empty.
        train_loss_avg = train_loss / max(len(train_loader), 1)
        valid_loss_avg = valid_loss / max(len(valid_loader), 1)
        train_acc = train_correct / max(len(train_loader.dataset), 1)
        valid_acc = valid_correct / max(len(valid_loader.dataset), 1)

        # Adjacent f-strings are used instead of backslash continuations inside
        # one literal, which leaked the source indentation into the log line.
        print(f'{time.time() - start:.3f}sec : [Epoch {epoch+1}/{num_epochs} -> '
              f'train loss: {train_loss_avg:.4f}, train acc: {train_acc*100:.3f}%/ '
              f'valid loss: {valid_loss_avg:.4f}, valid acc: {valid_acc*100:.3f}%')

        if history is not None:
            history.setdefault('train_loss', []).append(train_loss_avg)
            history.setdefault('train_acc', []).append(train_acc)
            history.setdefault('valid_loss', []).append(valid_loss_avg)
            history.setdefault('valid_acc', []).append(valid_acc)

    return model

In [None]:
# Dataset root; expected to contain `train`, `validation` and `test` sub-folders.
data_dir = 'dataset'

train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'validation')
test_dir = os.path.join(data_dir, 'test')
print(train_dir)

# Only the training split uses the augmenting transform; the test split is built
# in 'Test' mode so that it yields (image, path) instead of (image, grade).
train_dataset = CustomDataset(train_dir, transform = train_transforms)
valid_dataset = CustomDataset(valid_dir, transform = test_transforms)
test_dataset = CustomDataset(test_dir, mode = 'Test', transform = test_transforms)
print("train_dataset = ", len(train_dataset))
print("valid_dataset = ", len(valid_dataset))
print("test_dataset = ", len(test_dataset))

# Batch size shared by the training and validation loaders.
BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
# Evaluation data is not shuffled: shuffling brings no benefit there and only
# makes the batch composition differ from run to run.
valid_loader = DataLoader(valid_dataset, batch_size = BATCH_SIZE, shuffle = False)
# One image per batch, in dataset order, so each prediction maps to one file.
test_loader = DataLoader(test_dataset, batch_size = 1, shuffle = False)
print("train_loader = ", len(train_loader))
print("valid_loader = ", len(valid_loader))
print("test_loader = ", len(test_loader))

# Model, loss, optimizer and schedule used by the training cell.
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ResNet-18 initialised with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` - confirm against the installed version before changing it.
model_ft = models.resnet18(pretrained = True)
# Replace the final fully connected layer so that it has one output per class.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

model_ft = model_ft.to(device)
print("Model loaded")

criterion = nn.CrossEntropyLoss()

# All parameters (backbone and new head) are optimised.
optimizer_ft = optim.AdamW(model_ft.parameters(), lr = 0.0001)

# StepLR multiplies the learning rate by `gamma` after every `step_size` calls
# to `scheduler.step()`.
# NOTE(review): how often step() is called is decided inside train_model -
# verify that it is once per epoch, otherwise the rate decays far too quickly.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)

num_epochs = 50
print("Training start")


#     # inference code
#     f = open('배류나이 배류배류_test_result.csv', 'w', encoding ='utf-8', newline = '')
#     wr = csv.writer(f)
#     wr.writerow(['case', 'Predicted Severity', 'Inference time(ms)'])
#     with torch.no_grad():
#         model_ft.eval()
#         correct = 0
#         losses = 0

#         for img, files in test_loader:
#             img = img.to(device)
#             start = time.time()
#             pred = model_ft(img)
#             _, preds = torch.max(pred, 1)
#             end = time.time()
#             preds = preds.cpu().numpy()[0]

#             wr.writerow([os.path.basename(files[0])[:-4], class_names[preds], (end - start) * 1000])

#     f.close()

In [None]:
# Run the training loop on the objects prepared in the previous cell and keep
# the returned model under the same name.
model_ft = train_model(
    model = model_ft,
    criterion = criterion,
    optimizer = optimizer_ft,
    scheduler = exp_lr_scheduler,
    num_epochs = num_epochs,
    train_loader = train_loader,
    valid_loader = valid_loader,
)