In [None]:
import argparse
import copy
import datetime
import logging
import math
import os
import sys
import time
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision.datasets import ImageFolder


In [None]:
def initLogger(filename, loggerName):
    """Configure root logging to a file and stdout, then return a named logger.

    Args:
        filename: Path of the log file, opened by ``logging.FileHandler``.
        loggerName: Name passed to ``logging.getLogger``.

    Returns:
        The ``logging.Logger`` called ``loggerName``. Its records propagate to the
        root handlers configured here (file + stdout), at INFO level and above.
    """
    # force=True (Python 3.8+) removes and closes any handlers already attached to
    # the root logger. Without it, basicConfig does nothing on every call after the
    # first, so re-running this in the same kernel would keep writing to the
    # previous run's log file.
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] - %(message)s",
        handlers=[
            logging.FileHandler(filename=filename),
            logging.StreamHandler(sys.stdout),
        ],
        force=True,
    )
    return logging.getLogger(loggerName)


In [None]:
def getAcc(pred, label):
    """Count how many predictions equal their labels.

    Args:
        pred: Predicted class indices (a tensor such as ``out.argmax(1)``).
        label: Ground-truth class indices, same shape as ``pred``.

    Returns:
        The number of matching positions as a Python ``int``. This is a count,
        not a ratio.
    """
    # A single vectorised reduction replaces a Python loop over elements. Indexing
    # a CUDA tensor element by element forces a device-to-host sync per element.
    return int((pred == label).sum())


In [None]:
class MyModel(nn.Module):
    """Wrap a backbone that has an ``fc`` linear head and replace that head.

    Args:
        model: Backbone module with an ``fc`` attribute exposing ``in_features``
            (for example a torchvision ResNet). It is modified in place: its
            ``fc`` is replaced.
        numClasses: Output size of the new head. Defaults to 3, the value this
            notebook was originally written for.
    """

    def __init__(self, model, numClasses=3):
        super(MyModel, self).__init__()
        self.model = model
        # A freshly created nn.Linear has requires_grad=True, so the new head stays
        # trainable even if the backbone's parameters were frozen beforehand.
        self.model.fc = nn.Linear(self.model.fc.in_features, numClasses)

    def forward(self, input):
        """Return the wrapped model's output (``numClasses`` logits per sample)."""
        return self.model(input)


In [None]:
def _str2bool(value):
    """Parse a command-line boolean string.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, including "False". That would make ``--freeze False``
    freeze the backbone.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


# Command-line style configuration. parse_args(args=[]) makes the notebook always
# use the defaults below, regardless of how the kernel was launched.
parser = argparse.ArgumentParser()
parser.add_argument("--runFolderPath", type=str, default="./runs")
parser.add_argument("--epochSize", type=int, default=300)
parser.add_argument("--batchSize", type=int, default=128)
parser.add_argument("--modelName", type=str, default="resnet18")
parser.add_argument("--preTrained", type=_str2bool, default=True)
parser.add_argument("--freeze", type=_str2bool, default=False)
parser.add_argument("--trainPath", type=str, default="./obmorbinet/train")
parser.add_argument("--validatePath", type=str, default="./obmorbinet/validate")
parser.add_argument("--testPath", type=str, default="./obmorbinet/test")
parser.add_argument("--learningRate", type=float, default=0.0002)
args = parser.parse_args(args=[])

# makedirs creates missing parent folders and tolerates an existing folder.
os.makedirs(args.runFolderPath, exist_ok=True)

subModel = torch.hub.load(
    "pytorch/vision:v0.10.0", args.modelName, pretrained=args.preTrained
)

# Next run number = highest existing "train-<n>" + 1. Simply counting entries would
# reuse an existing number (and os.mkdir would fail) once an earlier run folder
# has been deleted.
_runPrefix = "train-"
existingRunNums = [
    int(entry[len(_runPrefix):])
    for entry in os.listdir(args.runFolderPath)
    if entry.startswith(_runPrefix) and entry[len(_runPrefix):].isdecimal()
]
runNum = max(existingRunNums, default=0) + 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the run paths from args.runFolderPath rather than a hard-coded "./runs",
# so they match the folder scanned above and the checkpoint paths, which also use
# args.runFolderPath.
runDir = os.path.join(args.runFolderPath, "train-{}".format(runNum))
os.mkdir(runDir)

myLogger = initLogger(
    os.path.join(runDir, "OBMORBINet-{}.log".format(time.strftime("%Y-%m-%d-%H-%M-%S"))),
    "myLogger",
)

myLogger.info("Start {} training, train num: {}".format(args.modelName, runNum))

myLogger.info("args:\n{}".format(args))


In [None]:
# Standard ImageNet channel statistics, as expected by torchvision's pretrained
# weights, and the input resolution fed to the backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)

# Training-time augmentation: random horizontal/vertical flips and up to ±15°
# rotation, applied before resizing and normalising.
trainTransform = T.Compose(
    [
        T.RandomHorizontalFlip(0.5),
        T.RandomVerticalFlip(0.5),
        T.RandomRotation(15),
        T.Resize(IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Deterministic preprocessing, shared by the validation and test splits.
evalTransform = T.Compose(
    [
        T.Resize(IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
validateTransform = evalTransform
testTransform = evalTransform

trainDataset = ImageFolder(args.trainPath, trainTransform)
validateDataset = ImageFolder(args.validatePath, validateTransform)
testDataset = ImageFolder(args.testPath, testTransform)

trainDataloder = DataLoader(trainDataset, batch_size=args.batchSize, shuffle=True)
# Evaluation does not need shuffling: correct-prediction counts do not depend on
# order, and a fixed order makes the logged evaluation numbers reproducible.
validateDataloder = DataLoader(validateDataset, batch_size=args.batchSize, shuffle=False)
testDataloder = DataLoader(testDataset, batch_size=args.batchSize, shuffle=False)

# Number of batches per pass (the last batch may be partial).
train_each_epoch_iter = math.ceil(len(trainDataset) / args.batchSize)
validate_each_epoch_iter = math.ceil(len(validateDataset) / args.batchSize)
test_each_epoch_iter = math.ceil(len(testDataset) / args.batchSize)


In [7]:
# Train a private copy of the loaded backbone. MyModel replaces the head in place,
# and training updates the weights in place. Wrapping `subModel` directly would
# therefore make a re-run of this cell continue from the previous run's weights.
backbone = copy.deepcopy(subModel)
if args.freeze:
    # Freeze the backbone. MyModel's new head is created afterwards and stays trainable.
    for param in backbone.parameters():
        param.requires_grad = False

model = MyModel(backbone).to(device)

lossFunction = nn.CrossEntropyLoss()

# Only parameters that can receive gradients are handed to the optimizer.
opt = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=args.learningRate,
    weight_decay=5e-4,
)

myLogger.info(opt)

# NOTE(review): these counters are never read in this cell. They are kept in case
# other cells use them.
trainIteration, validateIteration, testIteration = 1, 1, 1


def _evaluate(loader):
    """Run `model` in eval mode over `loader` without tracking gradients.

    Returns:
        (sum of per-batch mean losses, number of correct predictions)
    """
    model.eval()
    lossSum, correct = 0, 0
    with torch.no_grad():
        for evalImage, evalLabel in loader:
            evalImage = evalImage.to(device)
            evalLabel = evalLabel.to(device)
            evalOut = model(evalImage)
            lossSum += lossFunction(evalOut, evalLabel).item()
            correct += getAcc(evalOut.argmax(1), evalLabel)
    return lossSum, correct


for epoch in range(args.epochSize):
    trainLossValue, trainAccValue = 0, 0
    model.train()
    for image, label in trainDataloder:
        image = image.to(device)
        label = label.to(device)
        opt.zero_grad()
        out = model(image)
        loss = lossFunction(out, label)
        loss.backward()
        opt.step()
        pred = out.argmax(1)
        trainLossValue += loss.item()
        trainAccValue += getAcc(pred, label)

    validateLossValue, validateAccValue = _evaluate(validateDataloder)
    # NOTE(review): scoring the test split every epoch lets it influence which
    # checkpoint is chosen. Consider selecting on validation and testing once.
    testLossValue, testAccValue = _evaluate(testDataloder)

    # Saves the whole pickled module; loading it requires MyModel to be importable.
    torch.save(
        model,
        "{}/train-{}/model_{}.pth".format(args.runFolderPath, runNum, epoch),
    )

    # Logged as: [epoch, train/val/test mean batch loss, train/val/test accuracy].
    myLogger.info(
        [
            epoch,
            trainLossValue / train_each_epoch_iter,
            validateLossValue / validate_each_epoch_iter,
            testLossValue / test_each_epoch_iter,
            trainAccValue / len(trainDataset),
            validateAccValue / len(validateDataset),
            testAccValue / len(testDataset),
        ]
    )
