In [1]:
import torch

import cv2
import wandb
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
from datasets import load_metric

from torch import nn
import torch.nn as nn
from torchmetrics.classification import BinaryJaccardIndex
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from glob import glob
import torch.optim as optim
import pytorch_lightning as pl
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
from torch.optim.lr_scheduler import CosineAnnealingLR

In [2]:
class segDataset(Dataset):
    """Segmentation dataset that pairs images with single-channel mask labels.

    Files are discovered with ``glob`` under ``data_path`` (which is concatenated
    directly with the sub-paths, so it must end with a path separator) in the
    ``Positive`` and ``Negative`` sub-folders. Images and labels are paired purely
    by their sorted position, positives first, then negatives.

    Args:
        data_path: Root folder of the split, e.g. ``'./data/Train/'``.
        transforms: Optional callable applied to the stacked ``(4, H, W)`` tensor
            (3 image channels + 1 label channel), so that geometric transforms
            affect image and label identically.
    """

    def __init__(self, data_path, transforms=None):
        self.pos_imgs = sorted(glob(data_path + 'Positive/Image/*'))
        self.pos_labels = sorted(glob(data_path + 'Positive/Label/*'))
        self.neg_imgs = sorted(glob(data_path + 'Negative/Image/*'))
        # The original pattern contains a space ('Label /'). It is still honored
        # first in case the folder is really named that way; if it matches
        # nothing, fall back to the space-free folder name so negative labels
        # are not silently dropped.
        self.neg_labels = sorted(glob(data_path + 'Negative/Label /*'))
        if not self.neg_labels:
            self.neg_labels = sorted(glob(data_path + 'Negative/Label/*'))
        # Load positives and negatives together as one list.
        self.imgs = self.pos_imgs + self.neg_imgs
        self.labels = self.pos_labels + self.neg_labels
        self.transforms = transforms

    def __len__(self):
        # Only indices that have both an image and a label are valid.
        return min(len(self.imgs), len(self.labels))

    def __getitem__(self, item):
        """Return ``{'X': image, 'y': mask}`` as float32 tensors scaled by 1/255.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the label file.
        """
        img_path = self.imgs[item]
        label_path = self.labels[item]
        img = cv2.imread(img_path)
        label = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        if label is None:
            raise FileNotFoundError(f"Could not read label file: {label_path}")

        if label.ndim == 2:
            label = np.expand_dims(label, axis=2)
        else:
            # Multi-channel label file: keep only its first channel as the mask.
            label = label[:, :, :1]

        stacked = np.concatenate([img, label], axis=2)
        stacked = torch.from_numpy(stacked).permute(2, 0, 1)  # (h, w, c) -> (c, h, w)

        # Previously the result was only bound inside this branch, so a dataset
        # built without transforms crashed with UnboundLocalError.
        if self.transforms:
            stacked = self.transforms(stacked)

        X = stacked[:3].to(torch.float32)
        y = stacked[3].to(torch.float32)

        return {'X': X / 255, 'y': y / 255}



# Root of the training split. Must end with '/', because segDataset builds its
# glob patterns by plain string concatenation.
data_path = './data/Train/'

# Training pipeline: resize plus random horizontal flip. The dataset applies it
# to the stacked image+label tensor, so both are flipped together.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    # transforms.ToTensor(), # (h,w,c -> c,h,w) + Normalize
])

# Validation pipeline: resize only. Random augmentation on the validation set
# makes val_loss / val_iou noisy, and those drive early stopping and the LR
# scheduler.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
])

# Two views over the same files; they differ only in the transform applied.
training_data = segDataset(data_path=data_path, transforms=transform)
validation_data = segDataset(data_path=data_path, transforms=val_transform)

total_samples = len(training_data)
train_size = int(0.8 * total_samples)
val_size = total_samples - train_size

# Shuffle the indices with a fixed seed so the 80/20 split is reproducible
# across kernel restarts.
SPLIT_SEED = 7
indices = np.random.default_rng(SPLIT_SEED).permutation(total_samples).tolist()
train_sampler = SubsetRandomSampler(indices[:train_size])
val_sampler = SubsetRandomSampler(indices[train_size:])


# DataLoader setup: disjoint index subsets, each drawn from the dataset view
# with the matching transform.
train_dataloader = DataLoader(training_data, batch_size=16, sampler=train_sampler)
val_dataloader = DataLoader(validation_data, batch_size=16, sampler=val_sampler)

In [3]:
class ASPP(nn.Module):
    """Parallel dilated 3x3 convolutions whose outputs are fused by a 1x1 conv.

    One branch (Conv2d -> ReLU -> BatchNorm2d) is created per entry of ``rate``.
    Each branch uses ``padding == dilation`` so the spatial size is preserved.
    Branches are registered as ``aspp_block1``, ``aspp_block2``, ... so the
    default three-rate configuration keeps its original attribute names and
    state-dict keys.

    Args:
        in_dims: Number of input channels.
        out_dims: Number of output channels of every branch and of the module.
        rate: Iterable of dilation rates, one per branch. Defaults to (6, 12, 18).

    Raises:
        ValueError: If ``rate`` is empty.
    """

    def __init__(self, in_dims, out_dims, rate=(6, 12, 18)):
        super(ASPP, self).__init__()

        rates = list(rate)
        if not rates:
            raise ValueError("rate must contain at least one dilation value")
        # Previously exactly three branches were hard-coded while the fusion
        # conv was sized with len(rate), so any other length was broken.
        self.num_branches = len(rates)

        for index, dilation in enumerate(rates, start=1):
            branch = nn.Sequential(
                nn.Conv2d(
                    in_dims, out_dims, 3, stride=1, padding=dilation, dilation=dilation
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_dims),
            )
            # Assigning a Module attribute registers it as a submodule.
            setattr(self, f"aspp_block{index}", branch)

        self.output = nn.Conv2d(self.num_branches * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        """Run every branch on ``x``, concatenate on the channel axis, and fuse."""
        branch_outputs = [
            getattr(self, f"aspp_block{index}")(x)
            for index in range(1, self.num_branches + 1)
        ]
        out = torch.cat(branch_outputs, dim=1)
        return self.output(out)

    def _init_weights(self):
        """Kaiming-normal init for conv weights; BatchNorm set to weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class SpatialAttention(nn.Module):
    """Produce a single-channel gate in (0, 1) from per-pixel channel statistics.

    The channel-wise mean and maximum of the input are stacked into a two-channel
    map, passed through one bias-free convolution, and squashed with a sigmoid.

    Args:
        kernel_size: Side length of the convolution kernel; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()

        half_kernel = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=half_kernel, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the attention map computed from ``x`` (channels on dim 1)."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(stacked))
                
class ChannelAttention(nn.Module):
    """Produce a per-channel gate in (0, 1) from globally pooled features.

    Global average- and max-pooled descriptors go through the same two-layer
    1x1-conv bottleneck; the two results are summed and passed through a sigmoid.

    Args:
        in_planes: Number of input (and output) channels.
        ratio: Channel reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio``, but never less than 1.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # `ratio` used to be ignored in favor of a hard-coded 16. Clamp to at
        # least one hidden channel so small inputs do not get an empty bottleneck.
        hidden_planes = max(1, in_planes // ratio)
        self.fc = nn.Sequential(nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(hidden_planes, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(N, in_planes, 1, 1)`` gate for a 4-D input ``x``."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class ResidualBlock(nn.Module):
    """Pre-activation residual unit with a convolutional shortcut.

    Main path: BatchNorm -> ReLU -> 3x3 conv (``stride``, ``padding``) ->
    BatchNorm -> ReLU -> 3x3 conv (stride 1, padding 1).
    Shortcut path: a single 3x3 conv with the same ``stride``. Note that the
    shortcut always pads by 1, independent of the ``padding`` argument.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both paths.
        stride: Stride of the first main-path conv and of the shortcut conv.
        padding: Padding of the first main-path conv.
    """

    def __init__(self, in_channels, out_channels, stride, padding):
        super().__init__()

        main_layers = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        ]
        self.conv_block = nn.Sequential(*main_layers)

        projection = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.shortcut = nn.Sequential(projection)

    def forward(self, inputs):
        """Return the element-wise sum of the main path and the shortcut."""
        return self.conv_block(inputs) + self.shortcut(inputs)
    
class RPAResUNet(pl.LightningModule):
    """Residual U-Net with spatial/channel attention and ASPP blocks.

    Layout: an input residual block gated by spatial attention, three stride-2
    residual encoder blocks, an ASPP bridge gated by channel attention, three
    decoder stages (2x transposed-conv upsampling, concatenation with the
    matching encoder feature, residual block), and an ASPP + 1x1 conv + sigmoid
    head. Because the encoder halves the resolution three times and the decoder
    concatenates with encoder features, input height and width should be
    divisible by 8.

    Args:
        num_classes: Number of output channels of the head. The training and
            validation steps squeeze dim 1 and use binary cross-entropy, so they
            presumably expect ``num_classes == 1`` (that is how this file uses it).
        lr: Initial AdamW learning rate.
        lr_factor: Multiplicative LR decay applied by ReduceLROnPlateau when
            ``val_loss`` stops improving. The earlier hard-coded value of 1e-8
            dropped the learning rate to effectively zero on the first plateau;
            the default is now 0.1.
        lr_patience: Number of epochs without improvement before the LR decays.
    """

    def __init__(self, num_classes, lr=1e-3, lr_factor=0.1, lr_patience=5):
        super(RPAResUNet, self).__init__()
        self.jaccard = BinaryJaccardIndex()
        self.num_classes = num_classes
        self.lr = lr
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience

        # Input block -> residual layer
        self.residual_1_a = self.input_block(in_channels=3, out_channels=32)
        self.residual_1_b = self.input_skip(in_channels=3, out_channels=32)
        # Input block -> spatial attention
        self.spatial_attention = SpatialAttention()

        # Residual block x 3, encoding (each halves the resolution)
        self.residual_2 = ResidualBlock(32, 64, 2, 1)
        self.residual_3 = ResidualBlock(64, 128, 2, 1)
        self.residual_4 = ResidualBlock(128, 256, 2, 1)

        # Bridge block, ASPP
        self.bridge_aspp = ASPP(256, 512)

        # Last encoder feature <- channel attention
        self.channel_attention = ChannelAttention(512)

        # Decoder stage 1: upsample, then fuse with residual_3 output (128 ch)
        self.upsample_block_1 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0)
        self.residual_5 = ResidualBlock(512+128, 256, 1, 1)

        # Decoder stage 2: upsample, then fuse with residual_2 output (64 ch)
        self.upsample_block_2 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0)
        self.residual_6 = ResidualBlock(256+64, 128, 1, 1)

        # Decoder stage 3: upsample, then fuse with the input block output (32 ch)
        self.upsample_block_3 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0)
        self.residual_7 = ResidualBlock(128+32, 64, 1, 1)

        # Output block: ASPP followed by a 1x1 conv and sigmoid
        self.output_aspp = ASPP(64, 32)
        self.output = nn.Sequential(
          nn.Conv2d(in_channels=32, out_channels=num_classes, kernel_size=1),
          nn.Sigmoid(),
        )

    def input_block(self, in_channels, out_channels):
        """Build the main path of the input stage: conv -> BN -> ReLU -> conv."""
        block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                                    nn.BatchNorm2d(num_features=out_channels),
                                    nn.ReLU(),
                                    nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                                    )
        return block

    def input_skip(self, in_channels, out_channels):
        """Build the shortcut conv of the input stage."""
        skip = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        return skip

    def forward(self, X):
        """Return sigmoid outputs with ``num_classes`` channels at input resolution."""
        residual_1_a_out = self.residual_1_a(X)
        residual_1_b_out = self.residual_1_b(X)
        residual_1_out = residual_1_a_out + residual_1_b_out
        spatial_attention_out = self.spatial_attention(residual_1_out) * residual_1_out

        residual_2_out = self.residual_2(spatial_attention_out)
        residual_3_out = self.residual_3(residual_2_out)
        residual_4_out = self.residual_4(residual_3_out)

        bridge_aspp_a = self.bridge_aspp(residual_4_out)
        channel_attention_out = self.channel_attention(bridge_aspp_a) * bridge_aspp_a

        # Decoder blocks (no attention modules). The last skip connection uses
        # the un-gated input-stage feature (residual_1_out).
        upsample_1_out = self.upsample_block_1(channel_attention_out)
        residual_5_out = self.residual_5(torch.cat((upsample_1_out, residual_3_out), dim=1))

        upsample_2_out = self.upsample_block_2(residual_5_out)
        residual_6_out = self.residual_6(torch.cat((upsample_2_out, residual_2_out), dim=1))

        upsample_3_out = self.upsample_block_3(residual_6_out)
        residual_7_out = self.residual_7(torch.cat((upsample_3_out, residual_1_out), dim=1))

        output_aspp_b = self.output_aspp(residual_7_out)
        out = self.output(output_aspp_b)
        return out

    def configure_optimizers(self):
        """AdamW plus ReduceLROnPlateau driven by the logged ``val_loss``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',  # lower val_loss is better
                                                               factor=self.lr_factor,  # new_lr = lr * factor
                                                               patience=self.lr_patience,
                                                               verbose=True,)
        monitor_metric = 'val_loss'
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': monitor_metric  # metric the scheduler reacts to
            }
        }

    def _shared_step(self, batch):
        """Compute BCE loss and IoU for one batch with keys ``'X'`` and ``'y'``."""
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.functional.binary_cross_entropy(outputs, y)

        # Binarize both sides at 0.5. Casting the float target straight to uint8
        # floored every value below 1.0 to 0, so resize-interpolated mask pixels
        # were counted as background in the IoU.
        predicted_masks = (outputs > 0.5).to(torch.uint8)
        target_masks = (y > 0.5).to(torch.uint8)

        iou = self.jaccard(predicted_masks, target_masks)
        return loss, iou

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_iou`` and return the loss."""
        loss, iou = self._shared_step(batch)

        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_iou', iou, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_iou`` and return the loss."""
        loss, iou = self._shared_step(batch)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_iou', iou, prog_bar=True)

        return loss
        


In [None]:
# Initialize the Weights & Biases logger.
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# SECURITY: never embed an API key in the notebook. The key is read from the
# WANDB_API_KEY environment variable; when it is unset, wandb.login() falls back
# to stored credentials or an interactive prompt. The key that used to be
# hard-coded here has been exposed and should be revoked.
wandb.login(key=os.environ.get('WANDB_API_KEY'))
wandb_logger = WandbLogger(project='AISW', entity='hcim', name='RPAResUNet')

model = RPAResUNet(num_classes=1)

# Stop when val_loss has not improved for 8 consecutive validation runs.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=8,
    verbose=True,
    mode='min'
)

trainer = pl.Trainer(devices=[0],max_epochs=10, logger=wandb_logger, callbacks=[early_stopping_callback])
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchan4im[0m ([33mhcim[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

   | Name              | Type               | Params
----------------------------------------------------------
0  | jaccard           | BinaryJaccardIndex | 0     
1  | residual_1_a      | Sequential         | 10.2 K
2  | residual_1_b      | Conv2d             | 896   
3  | spatial_attention | SpatialAttention   | 98    
4  | residual_2        | ResidualBlock      | 74.1 K
5  | residual_3        | ResidualBlock      | 295 K 
6  | residual_4        | ResidualBlock      | 1.2 M 
7  | bridge_aspp       | ASPP               | 4.3 M 
8  | channel_attention | ChannelAttention   | 32.8 K
9  | upsample_block_1  | ConvTranspose2d    | 1.0 M 
10 | residual_5        | ResidualBlock      | 3.5 M 
11 | upsample_block_2  | ConvTranspose2d    | 262 K 
12 | residual_6        | ResidualBlock      | 886 K 
13 | upsample_block_3  | ConvTranspose2d    | 65.7 K
14 | residual_7        | ResidualBlock      | 221 K 
15 | output_aspp       | ASPP       

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.328


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.218 >= min_delta = 0.0. New best score: 0.110


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 0.092


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.085


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.084


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# NOTE: this cell holds only a string literal with a disabled copy of the
# segDataset class from the data-loading cell; nothing in it is executed.
"""
class segDataset(Dataset):
    def __init__(self, data_path, transforms=None):
        # cata path 설정 잘하기
        self.pos_imgs = sorted(glob(data_path + 'Positive/Image/*'))
        self.pos_labels = sorted(glob(data_path + 'Positive/Label/*'))
        self.neg_imgs = sorted(glob(data_path + 'Negative/Image/*'))
        self.neg_labels = sorted(glob(data_path + 'Negative/Label /*'))
        # positive와 negative를 합쳐서 불러오는 코드를 작성
        self.imgs = self.pos_imgs + self.neg_imgs
        self.labels = self.pos_labels + self.neg_labels
        self.transforms = transforms
        
    def __len__(self):
        return len(self.labels)
        
    def __getitem__(self, item):
        img_path = self.imgs[item]
        label_path = self.labels[item]
        img = cv2.imread(img_path)
        label = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        label = np.expand_dims(label, axis=2)

        concat = np.concatenate([img, label], axis=2)
        concat = torch.from_numpy(concat)
        concat = concat.permute(2,0,1) # (h,w,c -> c,h,w)

        
        if self.transforms:
            imgs = self.transforms(concat)

        X = imgs[:3].to(torch.float32)
        y = imgs[3].to(torch.float32)
            
        return {'X' : X/255, 'y': y/255}

"""

In [None]:
import random
import numpy as np
data_path = './data/Train/'

def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).

    Only randomness consumed after this call is affected.

    Args:
        seed: Integer seed passed to every generator.
    """
    random.seed(seed)
    # This line used to be a second random.seed(seed), leaving NumPy unseeded.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(7)

In [None]:
def tensor_to_image(tensor):
    """Convert a tensor with values in [0, 1] to a uint8 NumPy image.

    All size-1 dimensions are squeezed away. A remaining 3-D array whose first
    axis has length 3 is treated as channel-first colour and moved to
    channel-last; anything else (e.g. a 2-D mask) is returned in its squeezed
    layout.

    Args:
        tensor: Torch tensor, on any device, with or without gradient tracking.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` scaled to [0, 255].
    """
    # detach() lets tensors that still require grad be converted to NumPy.
    numpy_image = tensor.detach().squeeze().cpu().numpy()
    # Clip first so out-of-range values saturate instead of wrapping in uint8.
    numpy_image = (np.clip(numpy_image, 0.0, 1.0) * 255).astype(np.uint8)

    # Decide on the layout by rank, not by shape[0] alone: after squeeze() a 2-D
    # mask whose height happens to be 3 must not be mistaken for a colour image.
    if numpy_image.ndim == 3 and numpy_image.shape[0] == 3:
        numpy_image = np.transpose(numpy_image, (1, 2, 0))
    return numpy_image

In [None]:
import torch
from PIL import Image

model.eval()
# Deterministic inference preprocessing: resize only. A random horizontal flip
# here would mirror roughly half of the saved predictions relative to their
# inputs.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
])
def inference_and_save_images(input_dir, output_dir, start_index, end_index):
    """Run the global ``model`` on ``<index>.jpg`` files and save the predictions.

    Images are loaded the same way as in the training dataset (OpenCV BGR read,
    channel-first tensor, resize, scale by 1/255) so the channel order and value
    range match what the network was trained on.

    Args:
        input_dir: Folder containing ``<index>.jpg`` inputs.
        output_dir: Folder for the predictions; created if it does not exist.
        start_index: First file index (inclusive).
        end_index: Last file index (inclusive).

    Raises:
        IOError: If a prediction cannot be written.
    """
    # cv2.imwrite fails silently (returns False) when the folder is missing.
    os.makedirs(output_dir, exist_ok=True)
    # Feed inputs on whatever device the model currently lives on.
    device = next(model.parameters()).device

    for i in range(start_index, end_index + 1):
        input_path = os.path.join(input_dir, f"{i}.jpg")
        output_path = os.path.join(output_dir, f"{i}.jpg")

        # Skip indices with no file (e.g. Test/Positive/8322.jpg is missing).
        if not os.path.exists(input_path):
            print(f"File {input_path} not found. Skipping...")
            continue

        # Load and preprocess exactly like the training dataset.
        bgr = cv2.imread(input_path)
        if bgr is None:
            print(f"File {input_path} could not be decoded. Skipping...")
            continue
        img = torch.from_numpy(bgr).permute(2, 0, 1)  # (h, w, c) -> (c, h, w)
        img = transform(img).to(torch.float32) / 255
        img = img.unsqueeze(0).to(device)

        # Inference
        with torch.no_grad():
            output = model(img)

        # Convert the output tensor to an image and save it.
        output_image = tensor_to_image(output)
        if not cv2.imwrite(output_path, output_image):
            raise IOError(f"Failed to write prediction to {output_path}")

        print(f"Inferred and saved {input_path} to {output_path}")

# Positive predictions
positive_input_dir = "./data/Test/Positive/"
positive_output_dir = "./data/Prediction/Positive/"
inference_and_save_images(positive_input_dir, positive_output_dir, 8000, 8817)

# Negative predictions
negative_input_dir = "./data/Test/Negative/"
negative_output_dir = "./data/Prediction/Negative/"
inference_and_save_images(negative_input_dir, negative_output_dir, 1001, 1181)

In [None]:
# Shell command (IPython `!`): recursively archive ./data/Prediction into Prediction.zip.
!zip -r Prediction.zip ./data/Prediction

In [None]:
# NOTE: disabled debugging snippet kept as a string literal; it is not executed.
# It refers to x, y, outputs and batch_idx, which are not defined at notebook level.
"""
for i in range(x.shape[0]):
    x = x.permute(0, 2, 3, 1)
    cv2.imwrite(f"{batch_idx}_img.jpg", x[i].cpu().numpy()*255)
    cv2.imwrite(f"{batch_idx}_label.jpg", y[i].cpu().numpy()*255)
    cv2.imwrite(f"{batch_idx}_output.jpg", outputs[i].cpu().detach().numpy()*255) 
    """