In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

# PreProcessing

In [2]:
from PIL import Image
import numpy as np
from pathlib import Path
import os

In [3]:
def cut_image(file_path, save_path, name,  i):
    im = Image.open(file_path)
    imarray = np.array(im)
    im_h, im_w = imarray.shape[:2]
    block_h, block_w = 224, 224
    
    for r, row in enumerate(np.arange(im_h - block_h +1, step = block_h)):
        for c, col in enumerate(np.arange(im_w - block_w +1, step = block_w)):
            im1 = imarray[row:row+block_h, col:col+block_w, :]
            im1 = Image.fromarray(im1)
            im1.save(save_path + "/" + name + f"{i}" + f"{c}_{r}" +".png")

In [4]:
data_dir_path = Path("/content/drive/My Drive/develop/미생물/data/raw/")

for raw_path in data_dir_path.glob("*"):
    path = Path(raw_path)

    # print(path)
    
    microbe = str(path).split('/').pop()

    # print(microbe)

    save_path = "/content/data/edit/"+microbe
    
    # print(save_path)

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    
    files = [x for x in path.glob("**/*.tif") if x.is_file()]

    for i, file in enumerate(files):
        cut_image(file, save_path, microbe, i)
        if i > 3:
            break
    print("------", raw_path ,"------")

------ /content/drive/My Drive/develop/미생물/data/raw/Bifidobacterium.spp ------
------ /content/drive/My Drive/develop/미생물/data/raw/Clostridium.perfringens ------
------ /content/drive/My Drive/develop/미생물/data/raw/Escherichia.coli ------
------ /content/drive/My Drive/develop/미생물/data/raw/Acinetobacter.baumanii ------
------ /content/drive/My Drive/develop/미생물/data/raw/Bacteroides.fragilis ------
------ /content/drive/My Drive/develop/미생물/data/raw/Fusobacterium ------
------ /content/drive/My Drive/develop/미생물/data/raw/Enterococcus.faecalis ------
------ /content/drive/My Drive/develop/미생물/data/raw/Candida.albicans ------
------ /content/drive/My Drive/develop/미생물/data/raw/Enterococcus.faecium ------
------ /content/drive/My Drive/develop/미생물/data/raw/Actinomyces.israeli ------
------ /content/drive/My Drive/develop/미생물/data/raw/Lactobacillus.crispatus ------
------ /content/drive/My Drive/develop/미생물/data/raw/Listeria.monocyt

# pytorch

In [None]:
!nvidia-smi

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device == 'cuda':
    torch.cuda.set_device(0)
    print(torch.cuda.current_device())

0


In [None]:
# root = "/content/drive/My Drive/develop/미생물/data/raw/"
root = "/content/data/edit/"

dataset = torchvision.datasets.ImageFolder(root=root,
                                           transform=transforms.Compose([
                                               transforms.ToTensor(),
                                               transforms.Normalize((0.5, 0.5, 0.5),  
                                                                    (0.5, 0.5, 0.5)),
                                           ]))

In [None]:
len(dataset)

4752

In [None]:
dataset.class_to_idx

{'Acinetobacter.baumanii': 0,
 'Actinomyces.israeli': 1,
 'Bacteroides.fragilis': 2,
 'Bifidobacterium.spp': 3,
 'Candida.albicans': 4,
 'Clostridium.perfringens': 5,
 'Enterococcus.faecalis': 6,
 'Enterococcus.faecium': 7,
 'Escherichia.coli': 8,
 'Fusobacterium': 9,
 'Lactobacillus.crispatus': 10,
 'Listeria.monocytogenes': 11,
 'Micrococcus.spp': 12,
 'Veionella': 13,
 'models': 14}

In [None]:
train, val, test = torch.utils.data.random_split(dataset, [12000, 2000, 2200])
# train, val, test = torch.utils.data.random_split(dataset, [200, 50, 50])

In [None]:
bs = 64
train_dataloader = torch.utils.data.DataLoader(train,
                                         batch_size=bs,
                                         shuffle=True,
                                         num_workers=16)
val_dataloader = torch.utils.data.DataLoader(val,
                                         batch_size=bs,
                                         shuffle=True,
                                         num_workers=16)
test_dataloader = torch.utils.data.DataLoader(test,
                                         batch_size=bs,
                                         shuffle=True,
                                         num_workers=16)

In [None]:
model = models.resnet18(pretrained=False)

In [None]:
model.fc = nn.Linear(512, len(dataset.classes), bias=True)

In [None]:
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [None]:
def train_step(inputs, targets):
    optimizer.zero_grad()
    
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    
    batch_loss = loss.item()
    _, predictions = outputs.max(1)  # return values, indices
    correct = predictions.eq(targets).sum().item()
    
    return batch_loss, correct


def test_step(inputs, targets):
    with torch.no_grad():
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        batch_loss = loss.item()
        _, predictions = outputs.max(1)
        correct = predictions.eq(targets).sum().item()
        
        return batch_loss, correct

In [None]:
best_acc = 0.

In [None]:
EPOCHS = 10

for epoch in range(EPOCHS):
    
    train_loss = 0.
    train_total = 0
    train_correct = 0
    test_loss = 0.
    test_total = 0
    test_correct = 0
    
    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        batch_loss, correct = train_step(inputs, targets)
        train_loss += batch_loss
        train_total += targets.size(0)
        train_correct += correct
        
    for batch_idx, (inputs, targets) in enumerate(test_dataloader):
        batch_loss, correct = test_step(inputs, targets)
        test_loss += batch_loss
        test_total += targets.size(0)
        test_correct += correct
        
    test_accuracy = (test_correct/test_total)*100
        
    template = 'Epoch {}, Loss: {:.6f}, Accuracy: {:.3f}%, Test Loss: {:.6f}, Test Accuracy: {:.3f}%'
    print(template.format(epoch+1,
                          train_loss,
                          (train_correct/train_total)*100,
                          test_loss,
                          test_accuracy))
    
    if test_accuracy > best_acc:
        print("new best acc!")
        best_acc = test_accuracy
        torch.save(model.state_dict(), "/content/drive/My Drive/develop/미생물/resnet18.model")

Epoch 1, Loss: 6.843140, Accuracy: 98.775%, Test Loss: 1.218805, Test Accuracy: 98.864%
new best acc!
Epoch 2, Loss: 6.984822, Accuracy: 98.758%, Test Loss: 1.920674, Test Accuracy: 98.409%
Epoch 3, Loss: 6.884646, Accuracy: 98.708%, Test Loss: 1.868733, Test Accuracy: 98.455%
Epoch 4, Loss: 5.961758, Accuracy: 98.908%, Test Loss: 2.707401, Test Accuracy: 97.818%
Epoch 5, Loss: 15.424355, Accuracy: 97.333%, Test Loss: 1.990089, Test Accuracy: 98.227%
Epoch 6, Loss: 7.891608, Accuracy: 98.650%, Test Loss: 3.442610, Test Accuracy: 97.045%
Epoch 7, Loss: 6.083110, Accuracy: 99.008%, Test Loss: 1.032016, Test Accuracy: 98.773%
Epoch 8, Loss: 4.871987, Accuracy: 99.208%, Test Loss: 1.049752, Test Accuracy: 99.000%
new best acc!
Epoch 9, Loss: 3.007817, Accuracy: 99.508%, Test Loss: 0.856295, Test Accuracy: 99.409%
new best acc!
Epoch 10, Loss: 3.578362, Accuracy: 99.342%, Test Loss: 1.116382, Test Accuracy: 99.182%
