In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.functional as F
import os
from torchvision.datasets import VOCSegmentation
from torch.utils.data import Dataset , DataLoader
import copy
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('./train_logs')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.cuda.empty_cache()

In [3]:
from typing import Any, Tuple
import PIL
PIL.Image.ANTIALIAS = PIL.Image.LANCZOS
from PIL import Image
num_classes = 21
batch_size = 4
num_epoch = 100
class VOCSegDataset(VOCSegmentation):
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        image = Image.open(self.images[index]).convert('RGB')
        label = Image.open(self.targets[index])
        image = self.transform(image)
        label = self.target_transform(label)
        label = (label*255)
        return image , label.long()

In [4]:


image_transforms = transforms.Compose([
                                    transforms.Resize((256,256)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5) ),
                                     ])
target_transform = transforms.Compose([transforms.Resize((256,256)),
                                     transforms.ToTensor(),
                                     ])
train_dataset = VOCSegDataset('./data',
                             year='2012',download=False ,image_set='train', transform=image_transforms, target_transform=target_transform)
val_dataset = VOCSegDataset('./data',
                           year='2012',download=False, image_set='trainval', transform=image_transforms, target_transform=target_transform)

In [5]:
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=False,num_workers=0)
val_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False,num_workers=0)


In [8]:
from UNet import UNet
device = torch.device("cuda")
model = UNet(3,21) # pascalVOC has 21 classes including background
model = model.to(device)


In [9]:
#optimizer
import torch.optim 
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)


In [10]:
# 배치 수 확인
total_batch = len(train_loader)
print(total_batch)


366


In [11]:
import torch.nn as nn


celoss = nn.CrossEntropyLoss(ignore_index=255)

In [12]:
def make_sementic_image(img): # img is 256 256 tensor
    colors = [
    [128, 128, 128],    
    [0, 0, 255],   
    [0, 255, 0],   
    [255, 0, 0],   
    [255, 255, 0], 
    [0, 255, 255], 
    [255, 0, 255], 
    [192, 192, 192], 
    [128, 128, 128], 
    [128, 0, 0],   
    [128, 128, 0], 
    [0, 128, 0],   
    [128, 0, 128], 
    [0, 128, 128], 
    [0, 0, 128],   
    [139, 69, 19],  # 갈색. 사람
    [255, 165, 0],
    [255, 192, 203],
    [255, 255, 255], 
    [255, 105, 180], 
    [240, 230, 140]  
    ]   
    rt = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            rt[i, j] = colors[img[i, j]]
    return torch.from_numpy(rt).permute(2,0,1)
# Aeroplane	1
# Bicycle	2
# Bird	3
# Boat	4
# Bottle	5
# Bus	6
# Car	7
# Cat	8
# Chair	9
# Cow	10
# Diningtable	11
# Dog	12
# Horse	13
# Motorbike	14
# Person	15
# Pottedplant	16
# Sheep	17
# Sofa	18
# Train	19
# Tvmonitor	20
    

In [13]:
print('wut')
best_val_loss = 50

for epoch in range(num_epoch):
    train_loss = 0.0
    validation_loss =0.0
    validation_dice =0.0
    epoch_dice_scroe = 0.0

    model.train()
    for i ,data in enumerate(train_loader):
        
        inputs,labels = data
        inputs = inputs.to(device)
        inputs.requires_grad_(True)
        labels=labels.squeeze(1)
        labels = (labels).to(device)
 

        outputs = model(inputs)
        outputs_softmax = outputs.softmax(dim=1)

        celossval=celoss(outputs,labels)


   
        labels_no_255 = torch.where(labels >= 255, torch.zeros_like(labels), labels)

        loss = celossval

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss+= loss.item()
        
    print('training runing_loss : ',train_loss/(i+1))
    writer.add_scalar('train_loss',train_loss/(i+1),epoch)
    print('Finised Training... Start Validation.')
    model.eval()
    with torch.no_grad():
        for i,data in enumerate(val_loader):
            inputs,labels = data
            inputs = inputs.to(device)
            labels = labels.squeeze(1)
            labels = (labels).to(device)
            labels.requires_grad_(False)

            outputs = model(inputs)
           
            
            celossval=celoss(outputs,labels)
            
            outputs_softmax = outputs.softmax(dim=1)
           
            labels_no_255 = torch.where(labels >= 255, torch.zeros_like(labels), labels)
            
            loss = celossval
            
            validation_loss+= loss.item()
        print("validation_loss :", validation_loss/(i) )
        writer.add_scalar('validataion loss : ',validation_loss/(i),epoch)
        pred_idx_map = torch.argmax(outputs_softmax,dim=1)
        output_with_color = make_sementic_image(pred_idx_map[0])
        writer.add_image('model_output',output_with_color,epoch)
        writer.add_image('original_image',(inputs[0]*0.5+0.5),epoch)
        
        if best_val_loss > validation_loss/i:
            best_val_loss = validation_loss/i
            print("Saving Best Model")
            model_copy = copy.deepcopy(model).cpu()
            torch.save(model_copy,'./UNet.pth')
    


wut
training runing_loss :  1.4229253199908252
Finised Training... Start Validation.
validation_loss : 1.302691550372721
Saving Best Model
training runing_loss :  1.2662628484391123
Finised Training... Start Validation.
validation_loss : 1.2586691548208615
Saving Best Model
training runing_loss :  1.2291069103231846
Finised Training... Start Validation.
validation_loss : 1.2445706871266549
Saving Best Model
training runing_loss :  1.2052148907888132
Finised Training... Start Validation.
validation_loss : 1.2177973633202224
Saving Best Model
training runing_loss :  1.1895797132957178
Finised Training... Start Validation.
validation_loss : 1.1907126194910034
Saving Best Model
training runing_loss :  1.1782605827180415
Finised Training... Start Validation.
validation_loss : 1.1866759636163056
Saving Best Model
training runing_loss :  1.1695988497447447
Finised Training... Start Validation.
validation_loss : 1.1832208016907775
Saving Best Model
training runing_loss :  1.1594684846251389
Fi

KeyboardInterrupt: 