In [2]:
import torch 
import torchvision 
import numpy as np

import os
import pandas as pd
import albumentations as A
import albumentations.pytorch
import cv2

from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
# import torch.nn.functional as F

import sys
sys.path.append('../../')
from dataset import val_transforms, CDataset

In [3]:
import easydict 
# Run configuration, exposed with attribute access (args.batch_size, ...).
# The same object is handed to train() / validate() in the training cell.
args = easydict.EasyDict(dict(
    batch_size=256,        # batch size given to every DataLoader
    epochs=600,
    data=0,
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    start_epoch=0,
    gpu=5,                 # CUDA device index used when building `device`
    workers=8,             # DataLoader worker processes
    print_freq=1000,
    saved_dir='../../trained_models/stain_only/checkpoint.pt',  # checkpoint file path
))

In [4]:
# Create the directory that will hold the checkpoint file (no-op if it exists).
from pathlib import Path

# Take the parent of the checkpoint path instead of splitting the string on
# 'checkpoint': a split would truncate at the first directory whose name
# happens to contain that word.
path = Path(args.saved_dir).parent
path.mkdir(parents=True, exist_ok=True)

In [5]:
# os.environ["CUDA_VISIBLE_DEVICES"]= "7"
# device = torch.device("cuda")

# print('Device:', device)
# print('Current cuda device:', torch.cuda.current_device())
# print('Count of using GPUs:', torch.cuda.device_count())

In [6]:
# Report how many CUDA devices are visible, then pick the compute device.
ngpus_per_node = torch.cuda.device_count()
print(ngpus_per_node)

GPU_NUM = args.gpu  # index of the GPU to run on
if torch.cuda.is_available():
    device = torch.device(f'cuda:{GPU_NUM}')
    # Make the chosen GPU the current CUDA device as well.
    torch.cuda.set_device(device)
else:
    # torch.cuda.set_device() rejects a CPU device, so it is only called on
    # the CUDA branch; this keeps the CPU fallback usable.
    device = torch.device('cpu')
print(device)

6
cuda:5


In [8]:
IMAGE_SIZE = 256
from augment import HEColorAugment  # only referenced by the disabled stage below

# Stage 2: a OneOf group over horizontal flip, vertical flip and 90-degree rotation.
flip_or_rotate = A.OneOf([
    A.HorizontalFlip(p=0.8),
    A.VerticalFlip(p=0.8),
    A.RandomRotate90(p=0.8),
])

# Stage 3: random resized crop back to IMAGE_SIZE x IMAGE_SIZE, applied with p=0.5.
# NOTE(review): the scale upper bound is above 1.0 — confirm the installed
# albumentations version accepts this and that it is intended.
mild_crop = A.RandomResizedCrop(
    height=IMAGE_SIZE,
    width=IMAGE_SIZE,
    scale=[0.95, 1.05],
    ratio=[0.95, 1.05],
    p=0.5,
)

# Training-time augmentation pipeline: resize, flip/rotate group, crop.
train_transforms = A.Compose(
    [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE, p=1),
        flip_or_rotate,
        mild_crop,
        # Colour augmentations, currently disabled:
        # A.transforms.ColorJitter(brightness=0.05, contrast=0.2, saturation=0.2, hue=0.02, p=.8),
        # HEColorAugment(sigma1=0.1, sigma2=3, theta=18, p=.8),
    ],
    p=1.0,
)

In [9]:
# Load the CSV table for each split.
train_df = pd.read_csv('../../dataframe/train_df.csv')
val_df = pd.read_csv('../../dataframe/val_df.csv')
test_df = pd.read_csv('../../dataframe/test_df.csv')

# Wrap each table in a CDataset. Only the training split gets the augmenting
# train_transforms; validation and test share val_transforms from `dataset`.
train_dataset = CDataset(train_df, transform=train_transforms)
val_dataset = CDataset(val_df, transform=val_transforms)
test_dataset = CDataset(test_df, transform=val_transforms)

In [10]:
# Sanity check: pull the first item from the training dataset; each item
# unpacks into three values (image, label, path).
# NOTE(review): this rebinds `path`, which an earlier cell used for the
# checkpoint directory — confirm nothing later relies on the old value.
image, label, path = next(iter(train_dataset))

In [11]:
# Training batches are drawn in a freshly shuffled order.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                          shuffle=True, num_workers=args.workers)

# Evaluation loaders iterate in dataset order: shuffling gives no benefit when
# only scoring, and a fixed order keeps any per-sample outputs reproducible
# between runs.
val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                        shuffle=False, num_workers=args.workers)

test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                         shuffle=False, num_workers=args.workers)


In [12]:
# Sanity check: fetch one batch from the training loader and display the
# shape of the image tensor as the cell output.
images, labels, paths = next(iter(train_loader))
images.shape

torch.Size([256, 3, 256, 256])

In [13]:
import torchvision.models as models

# ResNet-18 with pretrained weights; the final fully connected layer is
# replaced by a 3-output head. The input width is read from the existing layer
# rather than hard-coded, so the line survives a change of backbone.
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 3)
model.to(device)

# Optimise only the parameters that require gradients.
params = [p for p in model.parameters() if p.requires_grad]
# optimizer = torch.optim.Adam(params, lr=0.0001)
# Hyperparameters come from `args`, so the config cell is the single place to
# tune them (the values match the literals previously written here).
optimizer = torch.optim.SGD(
    params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# Multiply the learning rate by 0.2 at epochs 10, 20, 30, 40 and 50.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.2)
# criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
# Import exactly the helpers this notebook calls; a star import would copy
# every public name from main.py into the notebook namespace.
from main import train, validate, save_checkpoint

best_acc1 = 0
acc1 = 0
train_loss = []  # value returned by train() for each epoch
val_acc = []     # value returned by validate() for each epoch

# NOTE(review): the loop runs 60 epochs; args.epochs (600) and
# args.start_epoch are not consulted here — confirm which is intended.
for epoch in range(60):
    losses = train(train_loader, model, criterion, optimizer, epoch, args)
    train_loss.append(losses)  # previously computed but never recorded
    scheduler.step()

    # evaluate after every epoch
    acc1 = validate(val_loader, model, criterion, args)
    val_acc.append(acc1)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    # Persist model, optimizer and scheduler state so a run can be resumed.
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, is_best, filename=args.saved_dir)
    

In [None]:
# Restore the weights stored at args.saved_dir and score them on the test split.
# map_location remaps the saved tensors onto the current `device`, so the load
# also works on a machine without the GPU that wrote the file.
# NOTE(review): whether this file holds the last-epoch or the best weights
# depends on how save_checkpoint() handles `is_best` — confirm.
checkpoint = torch.load(args.saved_dir, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
acc1 = validate(test_loader, model, criterion, args)