In [1]:
import os
import cv2
from PIL import Image
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
#from torch.optim.lr_scheduler import _LRScheduler

from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from torchvision import models
from torchsummary import summary
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Restrict this process to physical GPU index 3, then use CUDA when PyTorch
# reports it as available and fall back to the CPU otherwise.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

print(os.environ.get('CUDA_VISIBLE_DEVICES'))

cuda
3


In [3]:
# Run-length encoding helper
def rle_encode(mask):
    """Run-length encode a mask as a space-separated "start length ..." string.

    The mask is flattened with ``mask.flatten()`` and padded with a zero on
    both ends so that runs touching either edge still produce a transition.
    Start positions are 1-based indices into the flattened mask.
    """
    padded = np.concatenate([[0], mask.flatten(), [0]])
    # Positions at which the value differs from its predecessor.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Every second boundary closes a run: turn its end position into a length.
    boundaries[1::2] -= boundaries[::2]
    return ' '.join(map(str, boundaries))

# Per-class IoU helper
def calculate_iou_per_class(y_true, y_pred, class_id):
    """Intersection-over-union of one class between two label arrays.

    Returns 0 when ``class_id`` appears in neither array (empty union).
    """
    is_true = y_true == class_id
    is_pred = y_pred == class_id
    union = (is_true | is_pred).sum()
    if union > 0:
        return (is_true & is_pred).sum() / union
    return 0

def calculate_ious(y_true, y_pred, class_num):
    """Return a list with the IoU of every class id in ``range(class_num)``.

    A class that is absent from both arrays contributes 0.
    """
    def _single(class_id):
        hit_true = y_true == class_id
        hit_pred = y_pred == class_id
        union = np.sum(hit_true | hit_pred)
        return np.sum(hit_true & hit_pred) / union if union > 0 else 0

    return [_single(class_id) for class_id in range(class_num)]

def apply_fisheye_distortion(images, masks, label):
    """Warp a batch of images and label masks through OpenCV's undistortion model.

    Args:
        images: tensor of shape (batch, channel, height, width).
        masks: per-sample label maps; ``masks[i]`` is remapped with the same
            grid as ``images[i]`` (assumes an 8-bit (height, width) map — TODO confirm).
        label: distortion level; the second coefficient is set to ``0.02 * label``.

    Returns:
        (images, masks): stacked float images and long masks, moved to the
        notebook-level ``device``.

    Images are resampled bilinearly. Masks are resampled with nearest-neighbour
    so that class ids are never blended: bilinear sampling of a label map
    produces intermediate values at class boundaries, which round to
    unrelated class ids.
    """
    batch, channel, height, width = images.shape

    # Camera matrix: focal length of width / 4, principal point at the centre.
    focal_length = width / 4
    camera_matrix = np.array([[focal_length, 0, width / 2],
                              [0, focal_length, height / 2],
                              [0, 0, 1]], dtype=np.float32)

    # Distortion coefficients in OpenCV order (k1, k2, p1, p2); only k2 is set.
    dist_coeffs = np.array([0, 0.02 * label, 0, 0], dtype=np.float32)

    # The sampling grid depends only on the intrinsics and coefficients, so it
    # is built once per batch. cv2.undistort is documented as this map followed
    # by a bilinear cv2.remap.
    map_x, map_y = cv2.initUndistortRectifyMap(
        camera_matrix, dist_coeffs, None, camera_matrix, (width, height), cv2.CV_32FC1)

    warped_images = []
    warped_masks = []
    for i in range(batch):
        # CHW tensor -> contiguous HWC array for OpenCV.
        image = np.ascontiguousarray(images[i].permute(1, 2, 0).cpu().numpy())
        mask = masks[i].cpu().numpy()

        warped_image = cv2.remap(image, map_x, map_y, interpolation=cv2.INTER_LINEAR)
        warped_mask = cv2.remap(mask, map_x, map_y, interpolation=cv2.INTER_NEAREST)

        # Kept from the original: force uint8 and cap label values at 12.
        warped_mask = np.round(warped_mask).astype(np.uint8)
        warped_mask[warped_mask > 12] = 12

        # Back to tensors (HWC -> CHW for the image).
        warped_images.append(torch.from_numpy(warped_image).permute(2, 0, 1))
        warped_masks.append(torch.from_numpy(warped_mask))

    undistorted_images = torch.stack(warped_images, dim=0).float().to(device)
    undistorted_masks = torch.stack(warped_masks, dim=0).long().to(device)

    return undistorted_images, undistorted_masks

## 데이터셋

In [4]:
class CustomDataset(Dataset):
    """Image / segmentation-mask dataset driven by a CSV index.

    Column 1 of the CSV holds the image path and column 2 the mask path, both
    relative to ``root_dir``.

    Args:
        csv_file: path of the CSV index.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            (or ``transform(image=...)`` in inference mode) that returns a dict.
        infer: when True, ``__getitem__`` returns only the image.
        CL: when 0, mask values are remapped (see ``__getitem__``); any other
            value leaves the mask as read. NOTE(review): the meaning of the
            name "CL" is not visible here.
        root_dir: directory that the CSV paths are relative to.
    """

    def __init__(self, csv_file, transform=None, infer=False, CL=0, root_dir='./data/224'):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.infer = infer
        self.CL = CL
        self.root_dir = root_dir

    def __len__(self):
        return len(self.data)

    def _read_bgr(self, relative_path):
        """Read ``root_dir/relative_path`` with OpenCV, failing loudly if unreadable."""
        full_path = os.path.join(self.root_dir, relative_path)
        pixels = cv2.imread(full_path)
        if pixels is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image file: {full_path}")
        return pixels

    def __getitem__(self, idx):
        """Return ``image`` (inference mode) or ``(image, mask)`` for row ``idx``."""
        image = cv2.cvtColor(self._read_bgr(self.data.iloc[idx, 1]), cv2.COLOR_BGR2RGB)

        if self.infer:
            if self.transform:
                image = self.transform(image=image)['image']
            return image

        mask = cv2.cvtColor(self._read_bgr(self.data.iloc[idx, 2]), cv2.COLOR_BGR2GRAY)
        if self.CL == 0:
            mask = np.round(mask).astype(np.uint8)
            mask[mask > 12] = 12  # every value above 12 is treated as background
            # Shift labels up by one and move background to id 0:
            # original 0..11 -> 1..12, background (12) -> 0.
            mask += 1
            mask[mask == 13] = 0

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask
     

# Augmentation pipeline: Gaussian noise, horizontal flip and colour jitter,
# followed by conversion to tensors. No resize or normalisation step is applied.
transform = A.Compose([
    A.GaussNoise(var_limit=(10.0, 30.0), p=0.3),
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7),
    ToTensorV2(),
])

# Augmentation-free pipeline: only converts the sample to tensors.
totensor = A.Compose([ToTensorV2()])

## 모델 선언

In [5]:
from init_weights import init_weights

class GradReverse(torch.autograd.Function):
    """Gradient reversal: identity in the forward pass, sign flip in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # A view of the input with unchanged shape and values.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate the incoming gradient before it flows to earlier layers.
        return grad_output.neg()

class domain_classifier(nn.Module):
    """Domain head: flatten -> dropout -> gradient reversal -> two linear layers.

    Args:
        in_features: flattened feature size per row fed to the first linear
            layer. The default, 224*224*64, equals 32 channels x 224 x 448 —
            presumably the size of the feature map this head receives;
            TODO confirm against the actual input resolution.
        num_domains: number of output logits (one per domain label).

    The output is raw logits; no sigmoid or softmax is applied.
    """

    def __init__(self, in_features=224*224*64, num_domains=4):
        super(domain_classifier, self).__init__()
        self.in_features = in_features
        self.fc1 = nn.Linear(in_features, 50)
        # The tensor reaching dropout is 2-D (rows, features). nn.Dropout2d on a
        # 2-D input is deprecated in PyTorch, which advises nn.Dropout to keep
        # the same behaviour.
        self.dropout = nn.Dropout(0.5)
        # NOTE(review): the original comment described a source = 0 / target = 1
        # regression setup, which does not match the multi-logit output here.
        self.fc2 = nn.Linear(50, num_domains)

    def forward(self, x):
        # Rows are formed purely by element count, so the incoming tensor's
        # per-sample size must equal in_features for rows to be samples.
        x = x.view(-1, self.in_features)
        x = self.dropout(x)
        x = GradReverse.apply(x)  # reverses gradients flowing back to the features
        x = F.leaky_relu(self.fc1(x))
        x = self.fc2(x)

        return x


class Unet_block(nn.Module):
    """Two consecutive 3x3 conv -> batch-norm -> ReLU stages.

    Both convolutions use stride 1 and padding 1, so height and width are
    preserved; channels go in_channels -> mid_channels -> out_channels.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        first_stage = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(first_stage)))


class Nested_UNet(nn.Module):
    """Nested (UNet++-style) segmentation network with a domain-classifier head.

    Args:
        num_classes: number of segmentation output channels.
        input_channels: channels of the input image.
        deep_supervision: when True, four 1x1 heads on x0_1..x0_4 are averaged;
            otherwise a single 1x1 head on x0_4 is used.

    ``forward`` returns ``(segmentation_logits, domain_logits)``; the domain
    logits always come from the final decoder feature x0_4.
    """

    def __init__(self, num_classes, input_channels=3, deep_supervision=False):
        super().__init__()

        # Channel width at each depth, shallowest first.
        c0, c1, c2, c3, c4 = 32, 64, 128, 256, 512
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder column (x*_0)
        self.conv0_0 = Unet_block(input_channels, c0, c0)
        self.conv1_0 = Unet_block(c0, c1, c1)
        self.conv2_0 = Unet_block(c1, c2, c2)
        self.conv3_0 = Unet_block(c2, c3, c3)
        self.conv4_0 = Unet_block(c3, c4, c4)

        # Nested decoder: node x{d}_{k} concatenates k same-depth features
        # with one upsampled feature from depth d + 1.
        self.conv0_1 = Unet_block(c0 + c1, c0, c0)
        self.conv1_1 = Unet_block(c1 + c2, c1, c1)
        self.conv2_1 = Unet_block(c2 + c3, c2, c2)
        self.conv3_1 = Unet_block(c3 + c4, c3, c3)

        self.conv0_2 = Unet_block(2 * c0 + c1, c0, c0)
        self.conv1_2 = Unet_block(2 * c1 + c2, c1, c1)
        self.conv2_2 = Unet_block(2 * c2 + c3, c2, c2)

        self.conv0_3 = Unet_block(3 * c0 + c1, c0, c0)
        self.conv1_3 = Unet_block(3 * c1 + c2, c1, c1)

        self.conv0_4 = Unet_block(4 * c0 + c1, c0, c0)

        self.domain_classifier = domain_classifier()

        if self.deep_supervision:
            # Registers output1 .. output4, in that order.
            for head_index in (1, 2, 3, 4):
                setattr(self, f'output{head_index}', nn.Conv2d(c0, num_classes, kernel_size=1))
        else:
            self.output = nn.Conv2d(c0, num_classes, kernel_size=1)

        # Apply the project's 'kaiming' initialiser to every conv and batch-norm layer.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(module, init_type='kaiming')

    def forward(self, x):
        # NOTE(review): the original comment gave the input as (Batch, 3, 256, 256);
        # the size actually used is not visible here.
        def fuse(*feature_maps):
            """Concatenate feature maps along the channel axis."""
            return torch.cat(feature_maps, dim=1)

        pool, up = self.pool, self.up

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(pool(x0_0))
        x0_1 = self.conv0_1(fuse(x0_0, up(x1_0)))

        x2_0 = self.conv2_0(pool(x1_0))
        x1_1 = self.conv1_1(fuse(x1_0, up(x2_0)))
        x0_2 = self.conv0_2(fuse(x0_0, x0_1, up(x1_1)))

        x3_0 = self.conv3_0(pool(x2_0))
        x2_1 = self.conv2_1(fuse(x2_0, up(x3_0)))
        x1_2 = self.conv1_2(fuse(x1_0, x1_1, up(x2_1)))
        x0_3 = self.conv0_3(fuse(x0_0, x0_1, x0_2, up(x1_2)))

        x4_0 = self.conv4_0(pool(x3_0))
        x3_1 = self.conv3_1(fuse(x3_0, up(x4_0)))
        x2_2 = self.conv2_2(fuse(x2_0, x2_1, up(x3_1)))
        x1_3 = self.conv1_3(fuse(x1_0, x1_1, x1_2, up(x2_2)))
        x0_4 = self.conv0_4(fuse(x0_0, x0_1, x0_2, x0_3, up(x1_3)))

        if self.deep_supervision:
            head1 = self.output1(x0_1)
            head2 = self.output2(x0_2)
            head3 = self.output3(x0_3)
            head4 = self.output4(x0_4)
            output = (head1 + head2 + head3 + head4) / 4
        else:
            output = self.output(x0_4)

        # The domain head reads the same final feature map in both modes.
        x_d = self.domain_classifier(x0_4)

        return output, x_d
        

## Loss Function

In [6]:
# Loss function definitions
class DANN_Loss(nn.Module):
    """Segmentation cross-entropy plus a weighted domain cross-entropy.

    ``forward(result, label, domain_num, alpha=1)``:
        result: ``(label_logits, domain_logits)`` as returned by the model.
        label: segmentation target passed to CrossEntropyLoss with ``label_logits``.
        domain_num: domain class id shared by every row of ``domain_logits``.
            Non-integer values are truncated towards zero (``int(domain_num)``).
        alpha: weight of the domain term.

    Returns ``(loss, segment_loss, domain_loss)`` with
    ``loss = segment_loss + alpha * domain_loss``.
    """

    def __init__(self):
        super(DANN_Loss, self).__init__()
        self.CE = nn.CrossEntropyLoss()

    def forward(self, result, label, domain_num, alpha=1):
        label_logits, domain_logits = result

        # Rows of the domain logits, which need not equal the image batch size.
        batch_size = domain_logits.shape[0]

        segment_loss = self.CE(label_logits, label)

        # Build the target on the logits' own device so the loss does not
        # depend on a notebook-level `device` variable.
        domain_target = torch.full(
            (batch_size,), int(domain_num), dtype=torch.long, device=domain_logits.device)
        domain_loss = self.CE(domain_logits, domain_target)

        loss = segment_loss + alpha * domain_loss

        return loss, segment_loss, domain_loss
    
    

class DANN_Loss_mse(nn.Module):
    """Segmentation cross-entropy plus a weighted RMSE domain-regression term.

    ``forward(result, label, domain_num, alpha=1)``:
        result: ``(label_logits, domain_logits)`` as returned by the model.
        label: segmentation target for the cross-entropy term.
        domain_num: regression target shared by every row; it is truncated to
            an integer before use, as in the original implementation.
        alpha: weight of the domain term.

    Returns ``(loss, segment_loss, domain_loss)`` where ``domain_loss`` is the
    square root of the MSE between ``domain_logits`` and a (rows, 1) target.
    NOTE(review): the target has a single column, so this presumably expects a
    one-output domain head — confirm before using it with a wider head.
    """

    def __init__(self):
        super(DANN_Loss_mse, self).__init__()
        self.CE = nn.CrossEntropyLoss()
        self.MSE = nn.MSELoss()

    def forward(self, result, label, domain_num, alpha = 1):
        label_logits, domain_logits = result

        batch_size = domain_logits.shape[0]

        segment_loss = self.CE(label_logits, label)  # segmentation loss

        # Build the target on the logits' own device so the loss does not
        # depend on a notebook-level `device` variable.
        domain_target = torch.full(
            (batch_size, 1), float(int(domain_num)),
            dtype=torch.float32, device=domain_logits.device)
        domain_loss = self.MSE(domain_logits, domain_target)  # domain loss
        domain_loss = torch.sqrt(domain_loss)  # MSE -> RMSE
        loss = segment_loss + alpha * domain_loss
        return loss, segment_loss, domain_loss


# Active criterion: cross-entropy for both the segmentation and the domain term.
loss_fn = DANN_Loss().to(device)
# Alternative criterion with an RMSE domain term (disabled):
#loss_fn = DANN_Loss_mse().to(device)

# Earlier standalone criteria, kept for reference (disabled):
#criterion =nn.CrossEntropyLoss()
#domain_criterion = nn.BCELoss()
#criterion = nn.CrossEntropyLoss(weight=class_weights)

In [7]:
import os
# Debugging aid: sets CUDA_LAUNCH_BLOCKING=1 for this process.
# NOTE(review): this is CUDA's switch for synchronous kernel launches, which
# slows training; presumably left over from debugging — confirm it is still
# wanted, and that it is set early enough to take effect.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

## Parameters, Train

In [8]:
LR = 1e-4             # learning rate passed to Adam
EP = 30               # number of training epochs
BATCH_SIZE = 16       # samples per DataLoader batch
ACCMULATION_STEP = 1  # NOTE(review): not referenced in the visible cells
N_CLASSES = 13        # segmentation classes (ids 0..12)
alpha = 0.01          # weight of the domain loss inside the total loss
N_LABELS = 4          # distortion levels applied per training batch (0..3)
# model 초기화


In [9]:

# Model
model = Nested_UNet(num_classes = N_CLASSES,input_channels=3, deep_supervision=False).to(device)
#model.load_state_dict(torch.load('./data/resnet34_1118_1.pth'), strict=False)
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
optimizer.zero_grad()
# Training loader: random augmentation, shuffled
source_dataset = CustomDataset(csv_file='./data/896_csv/train_source.csv', transform=transform)
source_dataloader = DataLoader(source_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Validation loaders use the augmentation-free `totensor` pipeline: random
# noise / flips / colour jitter at evaluation time would make the reported
# losses and IoUs differ from run to run.
# Source validation set (distorted later, inside the validation loop)
val_source_dataset = CustomDataset(csv_file="./data/896_csv/val_source.csv", transform=totensor)
val_source_dataloader = DataLoader(val_source_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Target-like validation set (used as-is)
val_target_dataset = CustomDataset(csv_file="./data/896_csv/val_source_CL.csv", transform=totensor)
val_target_dataloader = DataLoader(val_target_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Count the model's parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"총 파라미터 수: {total_params}")

총 파라미터 수: 214684945


mIoU test

In [10]:
# Distortion level / domain label used for both validation passes.
# NOTE(review): 2.5 lies between the integer training levels; an integer-only
# domain criterion will truncate it — confirm that is intended.
VAL_DISTORTION_LABEL = 2.5


def _batch_class_ious(mask_batch, seg_logits):
    """Per-class IoU of the argmax prediction for one batch (list of N_CLASSES values)."""
    # argmax over raw logits selects the same class as argmax over softmax output.
    predictions = torch.argmax(seg_logits, dim=1).cpu().numpy()
    truth = mask_batch.cpu().numpy()
    return [calculate_iou_per_class(truth, predictions, class_id) for class_id in range(N_CLASSES)]


def _validate(loader, distort):
    """Run one gradient-free pass over ``loader``.

    When ``distort`` is True each batch is first warped with
    ``apply_fisheye_distortion`` at VAL_DISTORTION_LABEL.
    Returns (per-class IoU averaged over batches, mean seg loss,
    mean domain loss, mean total loss).
    """
    iou_rows = []
    total_sum = seg_sum = dom_sum = 0.0
    model.eval()
    with torch.no_grad():
        for images, masks in tqdm(loader):
            if distort:
                images, masks = apply_fisheye_distortion(images, masks, VAL_DISTORTION_LABEL)
            images = images.float().to(device)
            masks = masks.long().to(device)

            outputs = model(images)
            loss, seg_loss, dom_loss = loss_fn(outputs, masks, VAL_DISTORTION_LABEL, alpha = alpha)

            seg_sum += seg_loss.item()
            dom_sum += dom_loss.item()
            total_sum += loss.item()

            iou_rows.append(_batch_class_ious(masks, outputs[0]))

    per_class = np.mean(np.array(iou_rows).reshape(-1, N_CLASSES), axis=0)
    n_batches = len(loader)
    return per_class, seg_sum / n_batches, dom_sum / n_batches, total_sum / n_batches


def _report_validation(split_name, epoch_number, per_class, seg_loss, dom_loss, total_loss):
    """Print the per-class IoUs and losses of one validation pass; return the mIoU."""
    print()
    print(f"--IoU Scores {split_name} val--")
    for class_id, iou in enumerate(per_class):
        print(f'Class{class_id}: {iou:.4f}', end=" ")
        if (class_id+1) % 7 == 0:
            print()

    mean_iou = np.mean(per_class)

    print(f"\nEpoch{epoch_number}")
    print(f"Valid_Seg Loss: {seg_loss}",f"Valid_dom Loss: {dom_loss}")
    print(f"Valid {split_name} Loss: {total_loss}")
    print(f"Valid {split_name} mIoU: {mean_iou}" )
    print("___________________________________________________________________________________________\n")
    return mean_iou


for epoch in range(EP):
    model.train()
    epoch_loss = 0
    seg_losses = 0
    domain_losses = 0
    # One IoU list per distortion level (sized by N_LABELS, not hard-coded).
    train_class_ious = [[] for _ in range(N_LABELS)]

    for source_images, source_masks in tqdm(source_dataloader,desc=f"Epoch: {epoch+1}"):
        # Every batch is trained once per distortion level; the level doubles
        # as the domain label given to the loss.
        for label in range(N_LABELS):
            source_image, source_mask = apply_fisheye_distortion(source_images, source_masks, label)
            source_image = source_image.float().to(device)
            source_mask = source_mask.long().to(device)

            optimizer.zero_grad()
            source_outputs = model(source_image)
            target_loss, seg_loss, domain_loss = loss_fn(source_outputs, source_mask, label, alpha = alpha)
            target_loss.backward()
            optimizer.step()

            epoch_loss += target_loss.item()
            seg_losses += seg_loss.item()
            domain_losses += domain_loss.item()

            # Training IoU from the logits already computed for the loss. A
            # second forward pass in train mode would build another autograd
            # graph and run every layer again just for a metric.
            train_class_ious[label].extend(
                _batch_class_ious(source_mask, source_outputs[0].detach()))

    for i in range(N_LABELS):
        buff = np.array(train_class_ious[i]).reshape(-1,N_CLASSES)
        buff = np.mean(buff, axis=0)
        print(f"\nLabel_{i}: IoU Scores Train")
        for class_id, iou in enumerate(buff):
            print(f'Class{class_id:02d}: {iou:.4f}', end=" ")
            if (class_id+1) % 7 == 0:
                print()
    print()
    train_steps = N_LABELS*len(source_dataloader)
    print(f"Train seg Loss: {seg_losses/train_steps}",f"Train dom Loss: {domain_losses/train_steps}")
    print(f"Train Loss: {epoch_loss/train_steps}")
    print(f"Train mIoU: {np.mean(train_class_ious)}" )
    ################################################################
    # Validation 1: source images warped at VAL_DISTORTION_LABEL.
    fish_val_class_ious, val_seg_loss, val_domain_loss, val_epoch_loss = _validate(
        val_source_dataloader, distort=True)
    fish_val_mIoU = _report_validation(
        "Source", epoch+1, fish_val_class_ious, val_seg_loss, val_domain_loss, val_epoch_loss)

    # Validation 2: target-like images, used without extra distortion.
    fish_val_class_ious, val_seg_loss, val_domain_loss, val_epoch_loss = _validate(
        val_target_dataloader, distort=False)
    fish_val_mIoU = _report_validation(
        "Target", epoch+1, fish_val_class_ious, val_seg_loss, val_domain_loss, val_epoch_loss)

print("Hyperparamerters")
print(f"LR = {LR} | EP = {EP}, BATCH_SIZE = {BATCH_SIZE}, N_CLASSES = {N_CLASSES}, init_alpha = {alpha}, N_LABELS = {N_LABELS}")

Epoch: 1:   0%|          | 0/138 [00:00<?, ?it/s]

Epoch: 1:   0%|          | 0/138 [00:01<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Write the model's state_dict (not the module object itself) to a fixed path under ./data.
torch.save(model.state_dict(), './data/unet++_1121_1.pth')
