In [1]:
import os
import csv

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
from tqdm import tqdm
from torchsummary import summary

from data.dataset import SegmentationDataset
from model import someNet
from modules.Loss import WeightedCrossEntropyLoss
from train import train
from modules.Score import Score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def append_to_csv(file_path, data):
    with open(file_path, 'a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(data)

In [3]:
train_image_dir = "train/image/"
train_label_dir = "train/mask/"
log_path = "train_log.csv"
batch_size = 10
epochs=500
lr = 0.001

def get_file_names(folder_path):
    file_names = []
    for file_name in os.listdir(folder_path):
        if os.path.isfile(os.path.join(folder_path, file_name)):
            file_names.append(folder_path + file_name)
    return file_names

train_image_path = get_file_names(train_image_dir)
train_label_path = get_file_names(train_label_dir)

train_transforms = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])

In [4]:
train_raw_dataset = SegmentationDataset(train_image_path, train_label_path, train_transforms)

In [5]:
model = someNet()
# summary(model,(1,320,640))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if os.path.exists(log_path):
    os.remove(log_path)
header = ["train_loss", "valid_loss", "score", "dice", "iou", "lr"]
append_to_csv(log_path, header)

In [6]:
valid_size = 100
train_size = len(train_raw_dataset) - valid_size
train_dataset, valid_dataset = random_split(train_raw_dataset, [train_size, valid_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
full_loader = DataLoader(train_raw_dataset, batch_size=batch_size, shuffle=True)

In [7]:
optimizer = optim.Adam(model.parameters(), lr = lr)
criterion = nn.BCEWithLogitsLoss()
reduce_schedule = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
        verbose=False, threshold=1e-4, threshold_mode='rel',
        cooldown=0, min_lr=0, eps=1e-8)
scorer = Score()
try:
    os.makedirs("./outputs")
except:
    pass
img = valid_dataset[0][0].to(device)
img_display = (img*255).squeeze(0).int().cpu().numpy()
cv2.imwrite("./outputs/img.jpg", img_display)

for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        torch.sum(outputs == 0)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)
    reduce_schedule.step(loss)
    print(f"--- Epoch {epoch+1}/{epochs}: Train loss: {train_loss:.4f}", end = "")
    model.eval()
    valid_loss = 0.0
    for inputs, labels in valid_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        valid_loss += loss.item()
    valid_loss /= len(valid_loader)
    epoch_score = 0.0
    epoch_dice = 0.0
    epoch_iou = 0.0
    # epoch_hdict = 0.0
    for inputs, labels in full_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs)
            score, dice, iou = scorer(outputs, labels)
        epoch_score += score.item()
        epoch_dice += dice.item()
        epoch_iou += iou.item()
        # epoch_hdict += hdict.item()
    epoch_score /= len(full_loader)
    epoch_dice /= len(full_loader)
    epoch_iou /= len(full_loader)
    # epoch_hdict /= len(full_loader)
    print(f", valid loss: {valid_loss:.4f}, score: {epoch_score:.4f}")
    current_lr = optimizer.param_groups[0]['lr']
    outdata = [train_loss, valid_loss, epoch_score, epoch_dice, epoch_iou, current_lr]
    append_to_csv(log_path, outdata)
    
    label = valid_dataset[0][1].to(device)
    label_img = (label>0).int().cpu().numpy()
    with torch.no_grad():
        output = model(img.unsqueeze(0))
        pred_img = (output>0).int().cpu().numpy().squeeze(0)
    blank = np.zeros_like(pred_img)
    out = np.concatenate((pred_img,label_img,blank), axis = 0).transpose(1, 2, 0)*255
    cv2.imwrite("./outputs/{}.jpg".format(epoch+1), out)
    try:
        os.makedirs("./state_dict")
    except:
        pass
    
    if ((epoch+1)%10 == 0):
        torch.save(model.state_dict(), "./state_dict/{}.pt".format(epoch+1))


100%|██████████| 190/190 [01:18<00:00,  2.42it/s]


--- Epoch 1/500: Train loss: 0.6585, valid loss: 0.6327, score: 0.5356


100%|██████████| 190/190 [01:16<00:00,  2.47it/s]


--- Epoch 2/500: Train loss: 0.6195, valid loss: 0.6207, score: 0.5808


100%|██████████| 190/190 [01:17<00:00,  2.45it/s]


--- Epoch 3/500: Train loss: 0.6119, valid loss: 0.6156, score: 0.5987


100%|██████████| 190/190 [01:17<00:00,  2.46it/s]


--- Epoch 4/500: Train loss: 0.6078, valid loss: 0.6165, score: 0.6040


100%|██████████| 190/190 [01:21<00:00,  2.34it/s]


--- Epoch 5/500: Train loss: 0.6023, valid loss: 0.6077, score: 0.6161


100%|██████████| 190/190 [01:23<00:00,  2.26it/s]


--- Epoch 6/500: Train loss: 0.5989, valid loss: 0.6049, score: 0.6191


 26%|██▌       | 49/190 [00:20<00:59,  2.38it/s]


KeyboardInterrupt: 