폴더 경로 설정

In [1]:
workspace_path = '/app/HSK/FL_Seg'  # 파일 업로드한 경로 반영

In [2]:
from tensorboardX import SummaryWriter
summary = SummaryWriter()

### 필요한 패키지 로드

In [3]:
!pip install albumentations==0.4.6
!pip install   yacs

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [4]:
import os
import torch
import torch.nn.functional as F
import torchvision
import yaml
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random
import torch.backends.cudnn as cudnn
import time
import copy
from tqdm import tqdm



### 재구현 세팅

In [5]:
def init_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    if seed == 0:  # slower, more reproducible
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:  # faster, less reproducible
        cudnn.deterministic = False
        cudnn.benchmark = True

In [6]:
init_seeds(1)

### 데이터 로드

In [7]:
rgb_path = os.path.join(workspace_path, 'data/train/rgb/')
ngr_path = os.path.join(workspace_path, 'data/train/ngr/')
label_path = os.path.join(workspace_path, 'data/train/label/')

In [8]:
rgb_images = os.listdir(rgb_path)
rgb_images = [os.path.join(rgb_path,x) for x in rgb_images]
ngr_images = os.listdir(ngr_path)
ngr_images = [os.path.join(ngr_path, x) for x in ngr_images]
label_images = os.listdir(label_path)
label_images = [os.path.join(label_path, x) for x in label_images]

### 데이터셋 클래스 정의

In [9]:
class CloudDataset(torch.utils.data.Dataset):
    def __init__(self, image_path, label_path, patch_size = 400, patch_stride = 100, is_train = True, cache_dir = './cache', transforms = None):
        self.image_path = image_path
        self.label_path = label_path
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.is_train = is_train
        self.transforms = transforms
        
        self.patch_images = []
        self.patch_labels = []
        
        
        cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        if is_train:
            for img_path in self.image_path:
                img = cv2.imread(img_path)
                img_count = 0
                for x in range(0, img.shape[0]-self.patch_size+1, self.patch_stride):
                    for y in range(0, img.shape[1]-self.patch_size+1, self.patch_stride):
                        patch_image = img[x:x+patch_size, y:y+patch_size, :].copy()
                        patch_path = f'rgb_{os.path.splitext(os.path.basename(img_path))[0]}_{img_count}.png'
                        if not os.path.isfile(os.path.join(cache_dir, patch_path)):
                            cv2.imwrite(os.path.join(cache_dir, patch_path), patch_image)
                        self.patch_images.append(os.path.join(cache_dir, patch_path))
                        img_count += 1

            for label_path in self.label_path:
                img = cv2.imread(label_path)
                img_count = 0
                for x in range(0, img.shape[0]-self.patch_size+1, self.patch_stride):
                    for y in range(0, img.shape[1]-self.patch_size+1, self.patch_stride):
                        patch_image = img[x:x+patch_size, y:y+patch_size, :].copy()
                        patch_path = f'label_{os.path.splitext(os.path.basename(label_path))[0]}_{img_count}.png'
                        if not os.path.isfile(os.path.join(cache_dir, patch_path)):
                            cv2.imwrite(os.path.join(cache_dir, patch_path), patch_image)
                        self.patch_labels.append(os.path.join(cache_dir, patch_path))
                        img_count += 1
        else:
            self.patch_images = self.image_path
            self.patch_labels = self.label_path
    def __len__(self):
        return len(self.patch_images)
        
    def __getitem__(self, idx):
        img = cv2.imread(self.patch_images[idx])
        
        if self.is_train:
            label = cv2.imread(self.patch_labels[idx])
            # numpy arrays to tensors
            h, w = label.shape[:2]
        
            target = np.zeros((h, w), dtype=np.uint8)
            pos = np.where(np.all(label == [0, 0, 255], axis=-1))  # thick cloud
            target[pos] = 1
            pos = np.where(np.all(label == [0, 255, 0], axis=-1))  # thin cloud
            target[pos] = 2
            pos = np.where(np.all(label == [0, 255, 255], axis=-1))  # cloud shadow
            target[pos] = 3
        else:
            target = None
        if self.transforms is not None:
            img, target = self.transforms(img, target)
            
        if self.is_train:
            return img, target
        else:
            return img, self.patch_images[idx]

### 파라미터 세팅

In [10]:
batch_size = 8
epochs = 20

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= "2"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
patch_size = 400
patch_stride = 100
num_workers = 0

num_classes = 4
class_names = ['thick cloud', 'thin cloud', 'cloud shadow']

train_data_rate = 0.7

model_name = 'dilated_unet'

loss_func = 'dice'

### 데이터증대

In [11]:
class ImageAug:
    def __init__(self):
        self.aug = A.Compose([A.HorizontalFlip(p=0.5),
                             A.VerticalFlip(p=0.5),
                             A.ShiftScaleRotate(p=0.5),
                             A.RandomBrightnessContrast(p=0.3),
                             A.Normalize(),
                             ToTensorV2()])

    def __call__(self, img, label):
        transformed = self.aug(image=img, mask=label)
        return transformed['image'], transformed['mask']

class DefaultAug:
    def __init__(self):
        self.aug = A.Compose([A.Normalize(),
                             ToTensorV2()])

    def __call__(self, img, label):
        transformed = self.aug(image=img, mask=label)
        return transformed['image'], transformed['mask']

In [12]:
train_transforms = ImageAug()
val_transforms = DefaultAug()

### 데이터셋 정의

In [13]:
num_clients = 5

In [29]:
print(int(len(rgb_images)*train_data_rate*0/num_clients), int(len(rgb_images)*train_data_rate*1/num_clients))
print(int(len(rgb_images)*train_data_rate*1/num_clients), int(len(rgb_images)*train_data_rate*2/num_clients))

0 116
116 232


In [30]:
int(len(rgb_images)*train_data_rate)

580

In [17]:
#train dataset
clients = dict()
for i in range(num_clients):
    if i < (num_clients - 1):
        train_dataset = CloudDataset(rgb_images[int(len(rgb_images)*train_data_rate*i/num_clients):int(len(rgb_images)*train_data_rate*(i+1)/num_clients)], 
                                     label_images[int(len(rgb_images)*train_data_rate*i/num_clients):int(len(label_images)*train_data_rate*(i+1)/num_clients)],
                                    transforms=train_transforms, cache_dir=os.path.join(workspace_path, f'cache_{i}'))
    else:
        train_dataset = CloudDataset(rgb_images[int(len(rgb_images)*train_data_rate/num_clients*i):int(len(rgb_images)*train_data_rate)], 
                                     label_images[int(len(rgb_images)*train_data_rate/num_clients*i):int(len(rgb_images)*train_data_rate)],
                                    transforms=train_transforms, cache_dir=os.path.join(workspace_path, f'cache_{i}'))
        
    clients[f'train_dataloader_{i}'] = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                                   num_workers=num_workers, pin_memory=True, drop_last=True)

#valid dataset
val_dataset = CloudDataset(rgb_images[int(len(rgb_images)*train_data_rate):], label_images[int(len(label_images)*train_data_rate):],
                            transforms=val_transforms, cache_dir=os.path.join(workspace_path, 'cache_val'))
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers, pin_memory=True, drop_last=True)

In [35]:
total = 0
for i in range(num_clients):
    total += len(clients[f'train_dataloader_{i}'])

len_clients = dict()
for i in range(num_clients):   
    len_clients[f'train_dataloader_{i}'] = len(clients[f'train_dataloader_{i}']) / total

### 모델 정의

In [36]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [37]:
import torch.nn as nn

class DoubleConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConvBlock, self).__init__()
        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                                   nn.ReLU(inplace=True),
                                   nn.BatchNorm2d(out_channels),
                                   nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                                   nn.ReLU(inplace=True),
                                   nn.BatchNorm2d(out_channels))

    def forward(self, x):
        x = self.block(x)
        return x


class DilatedConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, padding):
        super(DilatedConvBlock, self).__init__()
        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, dilation=dilation),
                                   nn.ReLU(inplace=True),
                                   nn.BatchNorm2d(out_channels))

    def forward(self, x):
        x = self.block(x)
        return x




class ConcatDoubleConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConcatDoubleConvBlock, self).__init__()
        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                                   nn.ReLU(inplace=True),
                                   nn.BatchNorm2d(out_channels),
                                   nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                                   nn.ReLU(inplace=True),
                                   nn.BatchNorm2d(out_channels))

    def forward(self, x, skip):
        x = torch.cat((skip, x), dim=1)
        x = self.block(x)
        return x



class MyDilatedConvUNet(nn.Module):
    def __init__(self, filters=44, depth=3, bottleneck_depth=6):
        super(MyDilatedConvUNet, self).__init__()
        self.depth = depth
        self.encoder_path = nn.ModuleList()
        src_in_channels = 3     # Geo-TIFF has four channels (R, G, B, and NIR)
        for d in range(depth):
            in_channels = src_in_channels if d == 0 else filters * 2 ** (d-1)
            self.encoder_path.append(
                DoubleConvBlock(in_channels, filters * 2 ** d))
        self.maxpool = nn.MaxPool2d(2, 2, padding=0)
        self.bottleneck_path = nn.ModuleList()
        for d in range(bottleneck_depth):
            in_channels = filters * 2 ** (depth - 1) if d == 0 else filters * 2 ** depth
            self.bottleneck_path.append(DilatedConvBlock(in_channels, filters * 2 ** depth, 2 ** d, 2 ** d))
        self.decoder_path = nn.ModuleList()
        for d in range(depth):
            in_channels = filters * 2 ** (depth - d)
            self.decoder_path.append(ConcatDoubleConvBlock(in_channels, filters * 2 ** (depth - d - 1)))
        self.up_path = nn.ModuleList()
        for d in range(depth):
            in_channels = filters * 2 ** (depth - d)
            self.up_path.append(nn.ConvTranspose2d(in_channels, filters * 2 ** (depth - d - 1),
                                                        kernel_size=4, stride=2, padding=1))
        out_channels = 4     # output channels (num_classes + 1(background))
        self.last_conv = nn.Conv2d(filters, out_channels, kernel_size=1)

    def forward(self, x):
        skip = []
        for block in self.encoder_path:
            x = block(x)
            skip.append(x)
            x = self.maxpool(x)
        dilated = []
        for block in self.bottleneck_path:
            x = block(x)
            dilated.append(x)
        x = torch.stack(dilated, dim=-1).sum(dim=-1)  # sum over list

        # up-sampling and double convolutions
        for d in range(self.depth):
            x = self.up_path[d](x)
            x = self.decoder_path[d](x, skip[-(d+1)])

        return self.last_conv(x)

In [38]:
# Model
if model_name == 'deeplabv3':
    model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False, progress=True, num_classes=4)

elif model_name == 'dilated_unet':
    model = MyDilatedConvUNet()

model.to(device)

print('number of parameters: ', count_parameters(model))

number of parameters:  9083804


### Opimizer 정의

In [39]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

### 필요 함수 정의

In [40]:
def fitness_test(true, pred, num_classes=4):
    eps = 1e-7
    true_one_hot = F.one_hot(true.squeeze(1), num_classes=num_classes)  # (B, 1, H, W) to (B, H, W, C)
    true_one_hot = true_one_hot.permute(0, 3, 1, 2)  # (B, H, W, C) to (B, C, H, W)
    pred_max = pred.argmax(1)      # (B, C, H, W) to (B, H, W)
    pix_acc = (true == pred_max.unsqueeze(1)).sum().float().div(true.nelement())
    pred_one_hot = F.one_hot(pred_max, num_classes=num_classes)   # (B, H, W) to (B, H, W, C)
    pred_one_hot = pred_one_hot.permute(0, 3, 1, 2)   # (B, H, W, C) to (B, C, H, W)

    true_one_hot = true_one_hot.type(pred_one_hot.type())
    dims = (0,) + tuple(range(2, true.ndimension()))  # dims = (0, 2, 3)
    intersection = torch.sum(pred_one_hot & true_one_hot, dims)
    union = torch.sum(pred_one_hot | true_one_hot, dims)
    m_iou = (intersection / (union + eps)).mean()

    return m_iou.item(), pix_acc.item()

In [41]:
# Loss 함수 정의
def ce_loss(true, logits, ignore=255):
    """Computes the weighted multi-class cross-entropy loss.
    Args:
        true: a tensor of shape [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        ignore: the class index to ignore.
    Returns:
        ce_loss: the weighted multi-class cross-entropy loss.
    """
    ce_loss = F.cross_entropy(
        logits.float(),
        true.squeeze(1).long(),    # [B, H, W]
        ignore_index=ignore,
    )
    return ce_loss


def dice_loss(true, logits, eps=1e-7):
    """Computes the Sørensen–Dice loss.
    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the dice loss so we
    return the negated dice loss.
    Args:
        true: a tensor of shape [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.
    Returns:
        dice_loss: the Sørensen–Dice loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        # true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        true_1_hot = F.one_hot(true.squeeze(1), num_classes=num_classes)   # (B, 1, H, W) to (B, H, W, C)
        true_1_hot = true_1_hot.permute(0, 3, 1, 2)                        # (B, H, W, C) to (B, C, H, W)
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type()).contiguous()
    dims = (0,) + tuple(range(2, true.ndimension()))        # dims = (0, 2, 3)
    intersection = torch.sum(probas * true_1_hot, dims)     # intersection w.r.t. the class
    cardinality = torch.sum(probas + true_1_hot, dims)      # cardinality w.r.t. the class
    dice_loss = (2. * intersection / (cardinality + eps)).mean()
    return (1 - dice_loss)


def jaccard_loss(true, logits, eps=1e-7):
    """Computes the Jaccard loss, a.k.a the IoU loss.
    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the jaccard loss so we
    return the negated jaccard loss.
    Args:
        true: a tensor of shape [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.
    Returns:
        jacc_loss: the Jaccard loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = F.one_hot(true.squeeze(1), num_classes=num_classes)  # (B, 1, H, W) to (B, H, W, C)
        true_1_hot = true_1_hot.permute(0, 3, 1, 2)  # (B, H, W, C) to (B, C, H, W)
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type()).contiguous()
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    union = cardinality - intersection
    jacc_loss = (intersection / (union + eps)).mean()
    return (1 - jacc_loss)

### 학습 함수 정의

In [43]:
for train_dataloader in clients:
    print(len(clients[train_dataloader]))

710
710
710
710
710


In [44]:
len(clients[train_dataloader])

710

In [76]:
local_model = MyDilatedConvUNet()
for train_dataloader in clients:
    for i in local_model.state_dict():
        print(local_model.state_dict()[i])
        local_model.state_dict()[i] *= 2
        print(local_model.state_dict()[i])

tensor([[[[-0.1121,  0.1000,  0.0651],
          [-0.1140, -0.0559,  0.0525],
          [ 0.1330,  0.0794, -0.0128]],

         [[ 0.0955,  0.0332, -0.0098],
          [-0.1388, -0.0287, -0.0406],
          [-0.0275, -0.1667, -0.0098]],

         [[ 0.0918,  0.0931,  0.0031],
          [-0.1211,  0.0089, -0.0383],
          [-0.1843, -0.1456, -0.1355]]],


        [[[-0.0994, -0.1539, -0.1113],
          [ 0.0611,  0.0871, -0.0271],
          [ 0.0273,  0.1115,  0.1659]],

         [[ 0.1305,  0.0006, -0.0379],
          [-0.0285, -0.0161,  0.0318],
          [-0.0276,  0.1168,  0.0682]],

         [[-0.1095, -0.0154, -0.1545],
          [-0.0638, -0.0450, -0.0718],
          [ 0.1090,  0.1858,  0.1699]]],


        [[[ 0.0687, -0.1547, -0.1093],
          [-0.1857,  0.0932,  0.1748],
          [-0.0359, -0.1236,  0.1422]],

         [[ 0.1520,  0.0731,  0.0058],
          [-0.1456, -0.1808, -0.0328],
          [-0.1590,  0.0614, -0.1722]],

         [[-0.1476, -0.0477, -0.1462],
     

tensor([-1.3737e-02,  4.4029e-03,  1.0935e-02, -7.1441e-03, -8.0664e-03,
         1.5785e-02, -1.1788e-02, -1.4509e-04, -9.4785e-03,  1.3816e-02,
        -1.0676e-02,  6.7168e-03,  1.5907e-02,  3.2962e-03,  1.0288e-02,
         2.7579e-03, -1.8298e-03, -6.4192e-03, -6.7257e-03, -3.6076e-03,
        -8.8759e-03,  1.9518e-03,  4.3722e-03,  1.6760e-02, -5.4294e-03,
         1.0325e-02,  2.1034e-03, -1.0996e-02, -2.9579e-03,  1.4913e-02,
        -7.0840e-03,  1.7211e-02,  1.1550e-02,  7.7740e-03, -1.6423e-02,
         1.4279e-02,  9.9129e-04,  1.3729e-02, -5.6713e-03, -7.4484e-04,
         8.8296e-03, -8.3219e-04, -9.4775e-03, -4.1037e-03, -8.8924e-04,
         1.4195e-02, -4.3263e-03, -9.8101e-03,  1.0640e-02,  1.3163e-02,
        -9.7208e-03, -4.3686e-03, -6.3678e-03,  1.3575e-02, -7.5941e-03,
        -1.7188e-02,  3.8445e-03,  7.1316e-04, -7.8947e-04, -6.4813e-03,
        -1.2067e-03, -7.7717e-03, -1.7454e-02, -3.9208e-03, -1.3348e-02,
        -1.5360e-02,  5.1024e-04,  1.2179e-02,  1.0

tensor(0)
tensor(0)
tensor([[[[-0.0237,  0.0574, -0.0230],
          [-0.0563, -0.0339, -0.0914],
          [-0.0460, -0.0635, -0.0934]],

         [[ 0.0590, -0.0866,  0.0861],
          [-0.0996, -0.0900,  0.0322],
          [ 0.0461, -0.0246, -0.0235]],

         [[-0.0477,  0.0679,  0.0113],
          [-0.1004, -0.0661, -0.0052],
          [ 0.0383, -0.0048, -0.0674]],

         ...,

         [[ 0.0948, -0.0286,  0.0351],
          [ 0.0852,  0.0925,  0.0220],
          [-0.0922,  0.0416,  0.0160]],

         [[-0.0732, -0.0497, -0.0450],
          [ 0.0340,  0.0219,  0.0106],
          [-0.0213,  0.0762, -0.0148]],

         [[ 0.0059, -0.1001, -0.0542],
          [ 0.0004,  0.0979, -0.0654],
          [ 0.0529,  0.0280,  0.0887]]],


        [[[-0.0552, -0.0580,  0.0929],
          [ 0.0547, -0.0944, -0.0193],
          [-0.0649,  0.0306, -0.0436]],

         [[-0.0024, -0.0735, -0.0686],
          [ 0.0392,  0.0691,  0.0396],
          [ 0.0123, -0.0879, -0.0546]],

         [[

tensor([-0.0420, -0.0589,  0.0028,  0.0343, -0.0395,  0.0168,  0.0642, -0.0015,
        -0.0011,  0.0389,  0.0523,  0.0690,  0.0424,  0.0414,  0.0153, -0.0216,
        -0.0070, -0.0482,  0.0327,  0.0134,  0.0434,  0.0255,  0.0635, -0.0097,
        -0.0433,  0.0563, -0.0309,  0.0410, -0.0612,  0.0044, -0.0598,  0.0554,
         0.0003,  0.0407, -0.0421, -0.0624, -0.0092,  0.0297, -0.0232,  0.0327,
        -0.0688, -0.0337,  0.0618,  0.0661,  0.0354, -0.0086, -0.0688, -0.0431,
        -0.0541,  0.0071, -0.0461,  0.0590,  0.0193, -0.0664,  0.0667, -0.0309,
        -0.0578,  0.0280, -0.0442, -0.0372, -0.0550, -0.0142,  0.0072,  0.0228,
         0.0439, -0.0623,  0.0015, -0.0289, -0.0286,  0.0694,  0.0034, -0.0058,
        -0.0322,  0.0163, -0.0486,  0.0467, -0.0553, -0.0023, -0.0478,  0.0419,
        -0.0388,  0.0435,  0.0150,  0.0159,  0.0433,  0.0388,  0.0527, -0.0227,
        -0.0465,  0.0453,  0.0674, -0.0250,  0.0313, -0.0175, -0.0135,  0.0086,
         0.0405,  0.0293,  0.0045,  0.01

tensor([-0.1205,  0.0942, -0.1777,  0.2200,  0.0910, -0.1250, -0.1490,  0.2726,
         0.0837,  0.1782,  0.0518,  0.0826,  0.1542, -0.2621, -0.0822, -0.1437,
         0.2731,  0.2281,  0.2479, -0.2261, -0.0097,  0.2407, -0.1444, -0.2765,
        -0.1183, -0.2683,  0.2071,  0.0261,  0.0175, -0.2656,  0.1881, -0.2017,
         0.2643,  0.0862,  0.1063, -0.1656,  0.1803, -0.2608, -0.0570, -0.1827,
         0.0225, -0.1120,  0.2703, -0.2110,  0.2633,  0.2111, -0.2114, -0.0161,
         0.1603, -0.2591,  0.1207, -0.0883,  0.0399,  0.2664, -0.0596, -0.1798,
        -0.1542, -0.2451,  0.1138, -0.1719,  0.2484, -0.1422,  0.1216, -0.2100,
        -0.1948,  0.2490,  0.1792,  0.2386, -0.2549,  0.1606, -0.0653,  0.1709,
        -0.1551, -0.1254,  0.0914,  0.0956,  0.0363,  0.1942, -0.2569, -0.0711,
        -0.1310, -0.2275,  0.0691, -0.2827, -0.1741,  0.0615, -0.1465, -0.1906])
tensor([4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4.,
        4., 4., 4., 4., 4., 4., 4., 4.,

tensor(0)
tensor([[[[ 2.9768e-02,  6.8263e-02,  5.0451e-02],
          [ 4.1703e-02,  3.8764e-02,  6.0549e-03],
          [-3.7892e-03, -6.2453e-02,  1.2714e-02]],

         [[-1.7725e-03, -2.9610e-02, -4.6497e-02],
          [-2.3763e-02, -2.4623e-02, -6.6976e-02],
          [ 1.1895e-02,  4.9203e-02, -7.0755e-02]],

         [[-4.4184e-02,  1.5522e-02,  5.4839e-02],
          [-1.0458e-02,  4.1194e-02, -1.9989e-02],
          [-5.9059e-02,  3.4421e-03,  1.4062e-02]],

         ...,

         [[ 1.3799e-02,  4.5771e-02,  5.0259e-02],
          [ 3.5960e-02,  2.8327e-03,  4.4426e-03],
          [ 2.5754e-03, -2.2999e-02, -5.1244e-02]],

         [[-4.3642e-02,  3.1745e-02,  1.4710e-03],
          [ 2.5774e-02,  6.9948e-02, -1.4093e-02],
          [-6.8126e-02, -5.8302e-02,  2.4687e-03]],

         [[ 1.2469e-02, -3.9609e-02,  5.5122e-02],
          [-1.5777e-02,  4.4616e-02, -3.7667e-02],
          [ 2.3890e-02,  2.9662e-02, -1.5248e-02]]],


        [[[ 3.1128e-02, -3.3588e-02, -1.208

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.,

tensor([ 0.3097, -0.0539, -0.2269, -0.2259, -0.2966, -0.3394,  0.3684, -0.1054,
        -0.0597,  0.1469,  0.1002,  0.1132,  0.0324, -0.3308, -0.0433,  0.1783,
         0.1558, -0.3263, -0.0163, -0.2160, -0.4009, -0.1428, -0.0937,  0.2500,
         0.1557,  0.0229,  0.3702,  0.2396,  0.3328,  0.0885,  0.1357,  0.2914,
        -0.0578,  0.3164,  0.1876, -0.1952, -0.1810, -0.2185, -0.1677, -0.0811,
        -0.3460,  0.4011, -0.2766,  0.2985, -0.0660,  0.2025,  0.0026, -0.3894,
        -0.3507, -0.3714, -0.2985,  0.3868,  0.3148,  0.2054, -0.0330,  0.2753,
        -0.0160, -0.0235, -0.1893, -0.3736, -0.0375,  0.0734, -0.2520, -0.0480,
        -0.2370, -0.3645, -0.2965, -0.3011,  0.1716,  0.1219,  0.1728,  0.0801,
         0.0080, -0.1206, -0.3342, -0.0092,  0.0396,  0.1861, -0.2666, -0.3334,
         0.0959,  0.4017,  0.1036,  0.3166,  0.3055,  0.0077, -0.0887, -0.3412,
        -0.0664,  0.3868,  0.0738, -0.0164,  0.1106,  0.3112,  0.0807,  0.0992,
         0.1443,  0.1812, -0.0110, -0.12

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.,

tensor([[[[ 0.5012,  0.2094,  0.6913],
          [-0.7859, -0.6771,  0.5059],
          [-0.2901, -0.1571, -0.6883]],

         [[ 0.5271,  0.4591, -0.7238],
          [ 0.7554, -0.5805,  0.6534],
          [ 0.1916, -0.3888, -0.7399]],

         [[ 0.0161, -0.3267,  0.0137],
          [-0.5867,  0.0230, -0.4132],
          [ 0.0679,  0.4986,  0.5326]],

         ...,

         [[ 0.3576, -0.2793, -0.3154],
          [-0.1169, -0.2709, -0.3781],
          [ 0.4854, -0.0093, -0.7520]],

         [[-0.0169, -0.5336, -0.3139],
          [ 0.0384,  0.7678, -0.6331],
          [-0.2482,  0.2446,  0.2429]],

         [[ 0.5872, -0.2762,  0.5719],
          [-0.1987,  0.4018, -0.6254],
          [-0.5686,  0.1146, -0.5677]]],


        [[[ 0.3468,  0.0318,  0.1191],
          [-0.2245, -0.1046, -0.6002],
          [-0.6198, -0.0040,  0.6253]],

         [[ 0.7809, -0.5477,  0.2359],
          [-0.0958, -0.1528, -0.6688],
          [-0.6265,  0.4625,  0.1235]],

         [[ 0.3724, -0.3798, -0

In [62]:
len_clients['train_dataloader_0']

0.2

In [83]:
def train(optimizer, clients, val_dataloader, loss_func, epochs, device, patch_size=400, use_scheduler=False, save_path='./ckpt'):

    # Learning rate scheduler
    if use_scheduler:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=1)
    else:
        lr_scheduler = None
        
    start_epoch = 0
    resume = True

    if not os.path.isdir(save_path):
        os.mkdir(save_path)

    weight_file = save_path + '/{}.pt'.format(model_name)

    best_fit = 0.0
    num_epochs = epochs
    
    global_model = MyDilatedConvUNet()
    
    if resume:
        if os.path.exists(weight_file):
            checkpoint = torch.load(weight_file)
            model.load_state_dict(checkpoint['model'])
            start_epoch = checkpoint['epoch'] + 1
            best_fit = checkpoint['best_fit']
            print("Starting training for %g epochs..." % start_epoch)

    # Start training

    for epoch in range(start_epoch, num_epochs):
        # loss, metric = train_one_epoch(model, optimizer, dataloader, device, epoch)
        t0 = time.time()
        
        clients_loss = []        
        
        for train_dataloader in clients:
            local_model = MyDilatedConvUNet()
            local_model.load_state_dict(global_model.state_dict())
            
            if train_dataloader == 'train_dataloader_0':
                loss = train_one_epoch(local_model, optimizer, clients[train_dataloader], loss_func, device, epoch, num_epochs)
                clients_loss.append(loss.item)
                weights = copy.deepcopy(local_model.state_dict())
                
                for p in weights:
                    weights[p] = weights[p] * len_clients[train_dataloader]
                    
            else:
                loss = train_one_epoch(local_model, optimizer, clients[train_dataloader], loss_func, device, epoch, num_epochs)
                clients_loss.append(loss.item)
                
                for p in local_model.state_dict():
                    local_model.state_dict()[p] = local_model.state_dict()[p] * len_clients[train_dataloader]
                    weights[p] += local_model.state_dict()[p]
            
        t1 = time.time()
#         print('[Epoch %g] loss=%.4f, time=%.1f' % (epoch, sum(clients_loss) / len(clients_loss), t1 - t0)) ## 오류
        
        global_model.load_state_dict(weights)
        
        if lr_scheduler is not None:
            lr_scheduler.step(loss)
        #tb_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
        
        state = {'model_name': model_name, 'epoch': epoch, 'best_fit': best_fit, 'model': global_model.state_dict()}
        torch.save(state, weight_file)

        #tb_writer.add_scalar('train_epoch_loss', loss, epoch)
        
        torch.save(state, save_path + '/{}.pt'.format(model_name)) 
        
    # validation
    patch_size = patch_size
    fit = val_one_epoch(global_model, val_dataloader, device, epoch, num_epochs, patch_size)
    if fit > best_fit:
        print("best fit so far=>saved")
        torch.save(state, save_path + '/{}_{}_best.pt'.format(model_name, epoch))
        best_fit = fit
    torch.save(state, save_path + '/{}_{}.pt'.format(model_name, epoch))    
#         writer.add_scalar('D_loss_adv', d_loss_adv.item(), epoch)
#         writer.add_scalar('D_loss_cls', d_loss_cls.item(), epoch)


def train_one_epoch(model, optimizer, data_loader, loss_func, device, epoch, num_epochs):
    model.to(device)
    model.train()
    losses = np.array([])
    metrics = np.array([])
    bi0 = epoch * len(data_loader)  # batch index

    print(('\n' + '%10s' * 2) % ('Epoch', 'loss'))
    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    s = ('%10s' + '%10.4f') % (
        '-/%g' % (num_epochs - 1), 0.0)
    pbar.set_description(s)
    for i, (imgs, targets) in pbar:
        imgs, targets = imgs.to(device), targets.to(device)
        if model_name == 'deeplabv3':
            preds = model(imgs)['out']
            targets = targets.long()
        elif model_name == 'hrnet_w18' or model_name == 'hrnet_w48':
            preds = model(imgs)
            h, w = preds.shape[2], preds.shape[3]
            targets = F.interpolate(targets.float(), size=(h, w), mode='nearest').long()
        elif model_name == 'dilated_unet':
            preds = model(imgs)
            targets = targets.long()
            
        if loss_func == 'jaccard':
            loss = jaccard_loss(targets, preds)
        elif loss_func == 'dice':
            loss = dice_loss(targets, preds)
        elif loss_func == 'ce':
            loss = ce_loss(targets, preds)
        else:
            print('unsupported loss function')
            exit(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # cv2_imshow(imgs[0], preds[0])
            losses = np.append(losses, loss.item())

            s = ('%10s' + '%10.4f') % (
                '%g/%g' % (epoch, num_epochs - 1), loss.item())
            pbar.set_description(s)
            bi = bi0 + i
            #tb_writer.add_scalar('train_batch_loss', loss.item(), bi)

    epoch_loss = losses.mean()

    return epoch_loss


def val_one_epoch(model, data_loader, device, epoch, num_epochs, patch_size):
    model.eval()
    m_iou_list = np.array([])
    pix_acc_list = np.array([])

    print(('\n' + '%10s' * 3) % ('Epoch(V)', 'mIOU', 'Accuracy'))
    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    s = ('%10s' + '%10.4f' + ' %8.4f') % (
        '-/%g' % (num_epochs - 1), 0.0, 0.0)
    pbar.set_description(s)

    for i, (imgs, targets) in pbar:
        imgs, targets = imgs.to(device), targets.to(device)
        with torch.no_grad():
            if model_name == 'deeplabv3':
                preds = model(imgs)['out']
                targets = targets.long()
            elif model_name == 'hrnet_w18' or model_name == 'hrnet_w48':
                preds = model(imgs)
                h, w = preds.shape[2], preds.shape[3]
                targets = F.interpolate(targets.float(), size=(h, w), mode='nearest').long()
            elif model_name == 'dilated_unet':
                preds = model(imgs)
                targets = targets.long()

            m_iou, pix_acc = fitness_test(targets, preds)

            s = ('%10s' + '%10.4f' + ' %8.4f') % (
                '%g/%g' % (epoch, num_epochs - 1), m_iou, pix_acc)
            pbar.set_description(s)
            m_iou_list = np.append(m_iou_list, m_iou)
            pix_acc_list = np.append(pix_acc_list, pix_acc)
    val_m_iou_mean = m_iou_list.mean()
    val_pix_acc_mean = pix_acc_list.mean()
    print('[V] mIOU={:.3f}, Accuracy={:.3f}'.format(val_m_iou_mean, val_pix_acc_mean))
    #tb_writer.add_scalar('val_epoch_m_iou', val_m_iou_mean, epoch)
    #tb_writer.add_scalar('val_epoch_pix_acc', val_pix_acc_mean, epoch)
    return val_pix_acc_mean


### 학습 시작

In [84]:
train(optimizer, clients, val_dataloader, loss_func, epochs, device, patch_size=patch_size, save_path='/app/HSK/FL_Seg/codes/ckpt_FL')


     Epoch      loss


      0/19    0.7835: 100%|██████████| 710/710 [04:38<00:00,  2.55it/s]



     Epoch      loss


      0/19    0.7758: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      0/19    0.7522: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      0/19    0.7685: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      0/19    0.8098: 100%|██████████| 710/710 [04:40<00:00,  2.53it/s]



     Epoch      loss


      1/19    0.8014: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      1/19    0.7888: 100%|██████████| 710/710 [04:40<00:00,  2.53it/s]



     Epoch      loss


      1/19    0.7459: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      1/19    0.7665: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      1/19    0.8309: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      2/19    0.7936: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      2/19    0.7902: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      2/19    0.8014: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      2/19    0.7816: 100%|██████████| 710/710 [04:40<00:00,  2.54it/s]



     Epoch      loss


      2/19    0.7934: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      3/19    0.7733: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      3/19    0.7846: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      3/19    0.7850: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      3/19    0.7634: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      3/19    0.7760: 100%|██████████| 710/710 [04:40<00:00,  2.53it/s]



     Epoch      loss


      4/19    0.7693: 100%|██████████| 710/710 [04:40<00:00,  2.53it/s]



     Epoch      loss


      4/19    0.7714: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      4/19    0.7652: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      4/19    0.7730: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      4/19    0.7828: 100%|██████████| 710/710 [04:40<00:00,  2.53it/s]



     Epoch      loss


      5/19    0.7855: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      5/19    0.8287: 100%|██████████| 710/710 [04:38<00:00,  2.55it/s]



     Epoch      loss


      5/19    0.7667: 100%|██████████| 710/710 [04:40<00:00,  2.53it/s]



     Epoch      loss


      5/19    0.7818: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      5/19    0.7778: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      6/19    0.7714: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      6/19    0.7774: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      6/19    0.7742: 100%|██████████| 710/710 [04:40<00:00,  2.54it/s]



     Epoch      loss


      6/19    0.7731: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      6/19    0.8199: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


      7/19    0.8119: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


      7/19    0.8007: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


      7/19    0.7601: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


      7/19    0.7862: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


      7/19    0.7944: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


      8/19    0.7624: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


      8/19    0.7804: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


      8/19    0.7779: 100%|██████████| 710/710 [04:36<00:00,  2.56it/s]



     Epoch      loss


      8/19    0.7978: 100%|██████████| 710/710 [04:36<00:00,  2.56it/s]



     Epoch      loss


      8/19    0.7834: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


      9/19    0.7695: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


      9/19    0.8095: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


      9/19    0.7589: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


      9/19    0.8016: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


      9/19    0.7804: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


     10/19    0.9319: 100%|██████████| 710/710 [04:35<00:00,  2.58it/s]



     Epoch      loss


     10/19    0.9315: 100%|██████████| 710/710 [04:35<00:00,  2.58it/s]



     Epoch      loss


     10/19    0.8889: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


     10/19    0.8964: 100%|██████████| 710/710 [04:35<00:00,  2.57it/s]



     Epoch      loss


     10/19    0.9385: 100%|██████████| 710/710 [04:35<00:00,  2.58it/s]



     Epoch      loss


     11/19    0.9171: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


     11/19    0.9072: 100%|██████████| 710/710 [04:34<00:00,  2.59it/s]



     Epoch      loss


     11/19    0.9259: 100%|██████████| 710/710 [04:35<00:00,  2.58it/s]



     Epoch      loss


     11/19    0.9517: 100%|██████████| 710/710 [04:34<00:00,  2.58it/s]



     Epoch      loss


     11/19    0.9269: 100%|██████████| 710/710 [04:35<00:00,  2.58it/s]



     Epoch      loss


     12/19    0.9442: 100%|██████████| 710/710 [04:35<00:00,  2.58it/s]



     Epoch      loss


     12/19    0.9681: 100%|██████████| 710/710 [04:58<00:00,  2.38it/s]



     Epoch      loss


     12/19    0.9032: 100%|██████████| 710/710 [06:18<00:00,  1.87it/s]



     Epoch      loss


     12/19    0.9347: 100%|██████████| 710/710 [04:46<00:00,  2.48it/s]



     Epoch      loss


     12/19    0.8980: 100%|██████████| 710/710 [04:49<00:00,  2.45it/s]



     Epoch      loss


     13/19    0.9086: 100%|██████████| 710/710 [04:55<00:00,  2.41it/s]



     Epoch      loss


     13/19    0.8876: 100%|██████████| 710/710 [04:48<00:00,  2.46it/s]



     Epoch      loss


     13/19    0.9888: 100%|██████████| 710/710 [04:46<00:00,  2.48it/s]



     Epoch      loss


     13/19    0.9588: 100%|██████████| 710/710 [05:28<00:00,  2.16it/s]



     Epoch      loss


     13/19    0.9320: 100%|██████████| 710/710 [05:01<00:00,  2.35it/s]



     Epoch      loss


     14/19    0.9673: 100%|██████████| 710/710 [04:52<00:00,  2.43it/s]



     Epoch      loss


     14/19    0.9159: 100%|██████████| 710/710 [04:53<00:00,  2.42it/s]



     Epoch      loss


     14/19    0.9927: 100%|██████████| 710/710 [04:53<00:00,  2.42it/s]



     Epoch      loss


     14/19    0.8824: 100%|██████████| 710/710 [04:46<00:00,  2.48it/s]



     Epoch      loss


     14/19    0.9299: 100%|██████████| 710/710 [05:24<00:00,  2.19it/s]



     Epoch      loss


     15/19    0.8198: 100%|██████████| 710/710 [04:51<00:00,  2.44it/s]



     Epoch      loss


     15/19    0.8345: 100%|██████████| 710/710 [04:50<00:00,  2.44it/s]



     Epoch      loss


     15/19    0.8311: 100%|██████████| 710/710 [04:52<00:00,  2.43it/s]



     Epoch      loss


     15/19    0.8447: 100%|██████████| 710/710 [04:50<00:00,  2.44it/s]



     Epoch      loss


     15/19    0.8630: 100%|██████████| 710/710 [05:03<00:00,  2.34it/s]



     Epoch      loss


     16/19       nan: 100%|██████████| 710/710 [05:21<00:00,  2.21it/s]



     Epoch      loss


     16/19       nan: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


     16/19       nan: 100%|██████████| 710/710 [04:40<00:00,  2.53it/s]



     Epoch      loss


     16/19       nan: 100%|██████████| 710/710 [04:40<00:00,  2.53it/s]



     Epoch      loss


     16/19       nan: 100%|██████████| 710/710 [04:39<00:00,  2.54it/s]



     Epoch      loss


     17/19       nan: 100%|██████████| 710/710 [04:40<00:00,  2.54it/s]



     Epoch      loss


     17/19       nan: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


     17/19       nan: 100%|██████████| 710/710 [04:36<00:00,  2.56it/s]



     Epoch      loss


     17/19       nan: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


     17/19       nan: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


     18/19       nan: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


     18/19       nan: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


     18/19       nan: 100%|██████████| 710/710 [04:36<00:00,  2.56it/s]



     Epoch      loss


     18/19       nan: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



     Epoch      loss


     18/19       nan: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


     19/19       nan: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


     19/19       nan: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


     19/19       nan: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


     19/19       nan: 100%|██████████| 710/710 [04:37<00:00,  2.56it/s]



     Epoch      loss


     19/19       nan: 100%|██████████| 710/710 [04:36<00:00,  2.57it/s]



  Epoch(V)      mIOU  Accuracy


      -/19    0.0000   0.0000:   0%|          | 0/1525 [00:00<?, ?it/s]


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

### 최고 성능 모델 로드

In [25]:
save_path=os.path.join(workspace_path, 'codes/ckpt_deepv3')

checkpoint_path = os.path.join(save_path,'{}_best.pt'.format(model_name))
checkpoint = torch.load(checkpoint_path)

model.load_state_dict(checkpoint['model'])
model.to(device)

print('model load success')

model load success


### 테스트 데이터셋 정의

In [26]:
test_rgb_path = os.path.join(workspace_path, 'data/test/rgb')
test_rgb_images = os.listdir(test_rgb_path)
test_rgb_images = [os.path.join(test_rgb_path, x) for x in test_rgb_images]

In [27]:
#empty value
test_label_path = os.path.join(workspace_path, 'data/test/label')
try:
    test_label_images = os.listdir(test_label_path)
except:
    test_label_images = []
test_label_images = [os.path.join(test_label_path, x) for x in test_label_images]

In [28]:
test_dataset = CloudDataset(test_rgb_images, test_label_images,
                            transforms=val_transforms, is_train=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,
                                               num_workers=num_workers, pin_memory=True, drop_last=True)

### 테스트 결과 저장

In [29]:
model.eval()

result_path = os.path.join(workspace_path, 'results')
os.makedirs(result_path, exist_ok=True)

with torch.no_grad():
    pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
    for i, (imgs, img_path) in pbar:
        imgs = imgs.to(device)
        if model_name == 'deeplabv3':
            preds = model(imgs)['out']
        #elif model_name == 'hrnet_w18' or model_name == 'hrnet_w48':
        #    preds = model(imgs)
        #    h, w = preds.shape[2], preds.shape[3]
        elif model_name == 'dilated_unet':
            preds = model(imgs)
        
        pred_img = np.zeros((*list(preds.shape[2:]), 3), dtype=np.uint8)
        _, idx = preds.squeeze(0).max(0)
        pos = idx == 0
        pred_img[pos.cpu().numpy()] = [0, 0, 0]
        pos = idx == 1
        pred_img[pos.cpu().numpy()] = [0, 0, 255]
        pos = idx == 2
        pred_img[pos.cpu().numpy()] = [0, 255, 0]
        pos = idx == 3
        pred_img[pos.cpu().numpy()] = [0, 255, 255]
        
        cv2.imwrite(os.path.join(result_path, os.path.basename(img_path[0])), pred_img)


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [00:48<00:00,  4.27it/s]


### Run-Length Encoding

In [30]:
import pandas as pd

In [31]:
def mask2rle(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formatted
    '''
    pixels= img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

In [32]:
test_label_file_list = os.listdir(result_path)
test_label_path_list = [os.path.join(result_path, x) for x in test_label_file_list]

In [33]:
rle_list = []
for file_path in test_label_path_list:
    img = cv2.imread(file_path)
    rle = mask2rle(img)
    rle_list.append(rle)

In [34]:
my_dict = {'Image_Label':test_label_file_list, 'EncodedPixels':rle_list}

In [35]:
my_df = pd.DataFrame(my_dict)

In [36]:
my_df.to_csv(os.path.join(workspace_path, 'submission_deepv3.csv'), index=False)