In [1]:
import os
import csv

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
from tqdm import tqdm
from torchsummary import summary

from data.dataset import SegmentationDataset
from model import someNet
from modules.Loss import WeightedCrossEntropyLoss
from train import train
from modules.Score import Score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def append_to_csv(file_path, data):
    with open(file_path, 'a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(data)

In [3]:
train_image_dir = "train/image/"
train_label_dir = "train/mask/"
log_path = "train_log.csv"
batch_size = 10

def get_file_names(folder_path):
    file_names = []
    for file_name in os.listdir(folder_path):
        if os.path.isfile(os.path.join(folder_path, file_name)):
            file_names.append(folder_path + file_name)
    return file_names

train_image_path = get_file_names(train_image_dir)
train_label_path = get_file_names(train_label_dir)

train_transforms = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])

In [4]:
train_raw_dataset = SegmentationDataset(train_image_path, train_label_path, train_transforms)

In [5]:
model = someNet()
# summary(model,(1,320,640))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if os.path.exists(log_path):
    os.remove(log_path)
header = ["train_loss", "valid_loss", "score", "dice", "iou", "h_dist"]
append_to_csv(log_path, header)

In [6]:
valid_size = 100
train_size = len(train_raw_dataset) - valid_size
train_dataset, valid_dataset = random_split(train_raw_dataset, [train_size, valid_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
full_loader = DataLoader(train_raw_dataset, batch_size=batch_size, shuffle=True)

In [7]:
def train(epochs=100, lr=0.001):
    optimizer = optim.Adam(model.parameters(), lr = lr)
    criterion = nn.CrossEntropyLoss()
    scorer = Score()
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            print(train_loss)
        train_loss /= len(train_loader)
        print(f"--- Epoch {epoch+1}/{epochs}: Train loss: {train_loss:.4f}", end = "")

        model.eval()
        valid_loss = 0.0
        for inputs, labels in valid_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            valid_loss += loss.item()
        valid_loss /= len(valid_loader)
        epoch_score = 0.0
        epoch_dice = 0.0
        epoch_iou = 0.0
        epoch_hdict = 0.0
        for inputs, labels in full_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(inputs)
                score, dice, iou, hdict = scorer(outputs, labels)
            epoch_score += score.item()
            epoch_dice += dice.item()
            epoch_iou += iou.item()
            epoch_hdict += hdict.item()
        epoch_score /= len(full_loader)
        epoch_dice /= len(full_loader)
        epoch_iou /= len(full_loader)
        epoch_hdict /= len(full_loader)

        print(f", valid loss: {valid_loss:.4f}, score: {epoch_score:.4f}")
        outdata = [train_loss, valid_loss, epoch_score, epoch_dice, epoch_iou, epoch_hdict]
        append_to_csv(log_path, outdata)
        try:
            os.makedirs("./state_dict")
        except:
            pass
        torch.save(model.state_dict(), "./state_dict/{}.pt".format(epoch+1))
        

In [8]:
train()

  1%|          | 1/190 [00:03<10:52,  3.45s/it]

0.0


  1%|          | 2/190 [00:03<05:13,  1.67s/it]

0.0


  2%|▏         | 3/190 [00:04<03:22,  1.08s/it]

0.0


  2%|▏         | 4/190 [00:04<02:30,  1.24it/s]

0.0


  3%|▎         | 5/190 [00:05<02:00,  1.53it/s]

0.0


  3%|▎         | 6/190 [00:05<01:43,  1.79it/s]

0.0


  4%|▎         | 7/190 [00:05<01:31,  1.99it/s]

0.0


  4%|▍         | 8/190 [00:06<01:24,  2.16it/s]

0.0


  5%|▍         | 9/190 [00:06<01:18,  2.29it/s]

0.0


  5%|▌         | 10/190 [00:06<01:15,  2.39it/s]

0.0


  6%|▌         | 11/190 [00:07<01:12,  2.47it/s]

0.0


  6%|▋         | 12/190 [00:07<01:10,  2.52it/s]

0.0


  7%|▋         | 13/190 [00:08<01:09,  2.55it/s]

0.0


  7%|▋         | 14/190 [00:08<01:08,  2.57it/s]

0.0


  8%|▊         | 15/190 [00:08<01:07,  2.58it/s]

0.0


  8%|▊         | 16/190 [00:09<01:06,  2.60it/s]

0.0


  9%|▉         | 17/190 [00:09<01:06,  2.60it/s]

0.0


  9%|▉         | 18/190 [00:09<01:05,  2.61it/s]

0.0


 10%|█         | 19/190 [00:10<01:05,  2.61it/s]

0.0


 11%|█         | 20/190 [00:10<01:05,  2.61it/s]

0.0


 11%|█         | 20/190 [00:11<01:34,  1.80it/s]


KeyboardInterrupt: 