In [2]:
import os
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
%cd /content/drive/MyDrive/컴퓨터 비전/computer-vision
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from dataloader.dataset import CustomDataset
from model.unet import UNet
from tool.encode import rle_encode
from tool.transform import fisheye
import albumentations as A
from PIL import Image
from albumentations.pytorch import ToTensorV2

SAVE_ROOT = "/content/drive/MyDrive/컴퓨터 비전/computer-vision/model/checkpoint"
DATA_ROOT = "/content/drive/MyDrive/컴퓨터 비전/data"

def set_dataloader(csv_file, valid = False, batch_size = 16):
    def set_tta():
        transform = A.Compose(
            [
              fisheye([-1, 3.5, 0, 0]),
              A.CenterCrop(600, 930),
              A.Resize(224, 224),
              A.Normalize(),
              ToTensorV2()
            ])
        return transform

    def set_train_aug():
        transform = A.Compose(
            [
              fisheye([-1, 3.5, 0, 0]),
              A.CenterCrop(600, 930),
              A.Resize(224, 224),
              A.Normalize(),
              ToTensorV2()
            ])
        return transform

    if not valid: t = set_train_aug()
    else: t = set_tta()

    dataset = CustomDataset(data_root = DATA_ROOT, csv_file=csv_file, transform= t, infer=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    return dataloader

def set_model(load_from):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if not load_from:
        model = UNet().to(device)
    else:
        print("model lodaed...")
        model = torch.load(load_from, map_location=device)
    print("device:", device)
    return model, device

def train():
    load_from = "/content/drive/MyDrive/컴퓨터 비전/computer-vision/model/checkpoint/unet_epoch6_044.pt"
    model, device = set_model(load_from)
    train_dataloader = set_dataloader('train_source.csv')
    valid_dataloader = set_dataloader('val_source.csv')
    model_name = "unet"
    val_every = 5
    max_epoch = 20

    # loss function과 optimizer 정의
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    min_loss = 1e9
    for epoch in range(max_epoch):
        model.train()
        epoch_loss = 0
        for images, masks in tqdm(train_dataloader):
            images = images.float().to(device)
            masks = masks.long().to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks.squeeze(1))
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        train_loss = epoch_loss/len(train_dataloader)
        print(f'Epoch {epoch+1}, Train Loss: {train_loss}')

        # Validation loop
        if epoch==0 or (epoch+1)%val_every==0:
            model.eval()
            with torch.no_grad():
                vloss = 0
                for images, masks in tqdm(valid_dataloader):
                    images = images.float().to(device)
                    masks = masks.long().to(device)

                    outputs = model(images)
                    loss = criterion(outputs, masks.squeeze(1))
                    vloss += loss.item()
                val_loss = vloss / len(valid_dataloader)
                print(f'Valid Loss: {val_loss}')
                torch.save(model, os.path.join(SAVE_ROOT, model_name+f"_epoch{epoch+1+6}.pt"))
                if min_loss > val_loss:
                    torch.save(model, os.path.join(SAVE_ROOT, model_name+"_minimum.pt"))
                    print(f"Minimum Loss!! let's Saved... (Check {model_name}_minimum.pt)")
                    min_loss = val_loss
train()

/content/drive/MyDrive/컴퓨터 비전/computer-vision
model lodaed...
device: cuda


  self.pid = os.fork()
  self.pid = os.fork()
100%|██████████| 138/138 [21:48<00:00,  9.48s/it]


Epoch 6, Train Loss: 0.43870999156564905


100%|██████████| 30/30 [04:24<00:00,  8.83s/it]


Valid Loss: 0.5036573737859726
Minimum Loss!! let's Saved... (Check unet_epoch6_minimum.pt)


100%|██████████| 138/138 [06:39<00:00,  2.89s/it]


Epoch 7, Train Loss: 0.3067058243829271


100%|██████████| 138/138 [06:37<00:00,  2.88s/it]


Epoch 8, Train Loss: 0.27266100524128345


100%|██████████| 138/138 [06:50<00:00,  2.97s/it]


Epoch 9, Train Loss: 0.24638384580612183


100%|██████████| 138/138 [06:53<00:00,  3.00s/it]


Epoch 10, Train Loss: 0.22608739871909653


100%|██████████| 30/30 [01:22<00:00,  2.74s/it]


Valid Loss: 0.4487640529870987
Minimum Loss!! let's Saved... (Check unet_epoch10_minimum.pt)


100%|██████████| 138/138 [06:32<00:00,  2.84s/it]


Epoch 11, Train Loss: 0.21257984465447025


100%|██████████| 138/138 [06:53<00:00,  3.00s/it]


Epoch 12, Train Loss: 0.2084489731469016


100%|██████████| 138/138 [06:57<00:00,  3.02s/it]


Epoch 13, Train Loss: 0.19396280205768088


100%|██████████| 138/138 [06:48<00:00,  2.96s/it]


Epoch 14, Train Loss: 0.1831589668341305


100%|██████████| 138/138 [06:40<00:00,  2.90s/it]


Epoch 15, Train Loss: 0.1807951324659845


100%|██████████| 30/30 [01:19<00:00,  2.64s/it]


Valid Loss: 0.4620408207178116


100%|██████████| 138/138 [06:35<00:00,  2.87s/it]


Epoch 16, Train Loss: 0.16682593821399455


100%|██████████| 138/138 [06:46<00:00,  2.95s/it]


Epoch 17, Train Loss: 0.16577201207046924


100%|██████████| 138/138 [06:42<00:00,  2.92s/it]


Epoch 18, Train Loss: 0.15185740152778832


100%|██████████| 138/138 [06:47<00:00,  2.96s/it]


Epoch 19, Train Loss: 0.1467318802640058


100%|██████████| 138/138 [06:45<00:00,  2.93s/it]


Epoch 20, Train Loss: 0.14268523157722707


100%|██████████| 30/30 [01:19<00:00,  2.66s/it]


Valid Loss: 0.50273524026076


100%|██████████| 138/138 [06:43<00:00,  2.93s/it]


Epoch 21, Train Loss: 0.13302927145707436


100%|██████████| 138/138 [06:40<00:00,  2.90s/it]


Epoch 22, Train Loss: 0.12836873061631038


 59%|█████▊    | 81/138 [04:01<03:24,  3.59s/it]