In [1]:
import torch

import cv2
import wandb
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
from datasets import load_metric

from torch import nn
import torch.nn as nn
from torchmetrics.classification import BinaryJaccardIndex
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from glob import glob
import torch.optim as optim
import pytorch_lightning as pl
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
from torch.optim.lr_scheduler import CosineAnnealingLR

In [2]:
class segDataset(Dataset):
    def __init__(self, data_path, transforms=None):
        # cata path 설정 잘하기
        self.pos_imgs = sorted(glob(data_path + 'Positive/Image/*'))
        self.pos_labels = sorted(glob(data_path + 'Positive/Label/*'))
        self.neg_imgs = sorted(glob(data_path + 'Negative/Image/*'))
        self.neg_labels = sorted(glob(data_path + 'Negative/Label /*'))
        # positive와 negative를 합쳐서 불러오는 코드를 작성
        self.imgs = self.pos_imgs + self.neg_imgs
        self.labels = self.pos_labels + self.neg_labels
        self.transforms = transforms
        
    def __len__(self):
        return len(self.labels)
        
    def __getitem__(self, item):
        img_path = self.imgs[item]
        label_path = self.labels[item]
        img = cv2.imread(img_path)
        label = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        label = np.expand_dims(label, axis=2)

        concat = np.concatenate([img, label], axis=2)
        concat = torch.from_numpy(concat)
        concat = concat.permute(2,0,1) # (h,w,c -> c,h,w)

        
        if self.transforms:
            imgs = self.transforms(concat)

        X = imgs[:3].to(torch.float32)
        y = imgs[3].to(torch.float32)
            
        return {'X' : X/255, 'y': y/255}



data_path = './data/Train/'
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    # transforms.ToTensor(), # (h,w,c -> c,h,w) + Normalize
])

training_data = segDataset(data_path=data_path, transforms=transform)

total_samples = len(training_data)
train_size = int(0.8 * total_samples)
val_size = total_samples - train_size

# 인덱스를 무작위로 섞음
indices = list(range(total_samples))
np.random.shuffle(indices)
train_sampler = SubsetRandomSampler(indices[:train_size])
val_sampler = SubsetRandomSampler(indices[train_size:])


# DataLoader 설정
train_dataloader = DataLoader(training_data, batch_size=16, sampler=train_sampler)
val_dataloader = DataLoader(training_data, batch_size=16, sampler=val_sampler)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from modeling.aspp import build_aspp
from modeling.decoder import build_decoder
from modeling.backbone import build_backbone

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )


    def forward(self,x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
		    nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x

class Recurrent_block(nn.Module):
    def __init__(self,ch_out,t=2):
        super(Recurrent_block,self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
		    nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True)
        )

    def forward(self,x):
        for i in range(self.t):

            if i==0:
                x1 = self.conv(x)
            
            x1 = self.conv(x+x1)
        return x1
        
class RRCNN_block(nn.Module):
    def __init__(self,ch_in,ch_out,t=2):
        super(RRCNN_block,self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out,t=t),
            Recurrent_block(ch_out,t=t)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)

    def forward(self,x):
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x+x1


class single_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(single_conv,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.conv(x)
        return x

class Attention_block(nn.Module):
    def __init__(self,F_g,F_l,F_int):
        super(Attention_block,self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
            )
        
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)

        return x*psi



In [5]:
class U_Net(pl.LightningModule):
    def __init__(self,img_ch=3,output_ch=1):
        super(U_Net,self).__init__()
        self.jaccard = BinaryJaccardIndex()
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)

        self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
        self.Conv2 = conv_block(ch_in=64,ch_out=128)
        self.Conv3 = conv_block(ch_in=128,ch_out=256)
        self.Conv4 = conv_block(ch_in=256,ch_out=512)
        self.Conv5 = conv_block(ch_in=512,ch_out=1024)

        self.Up5 = up_conv(ch_in=1024,ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512,ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        
        self.Up3 = up_conv(ch_in=256,ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        
        self.Up2 = up_conv(ch_in=128,ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
        self.sigmoid = nn.Sigmoid()


    def forward(self,x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4,d5),dim=1)
        
        d5 = self.Up_conv5(d5)
        
        d4 = self.Up4(d5)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        d1 = self.sigmoid(d1)

        return d1

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',	# Loss최소화,최대화 결정
                                                               factor=1e-8,	# 학습률 감소율
                                                               patience=5,
                                                               verbose=True,)
        monitor_metric = 'val_loss'
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': monitor_metric  # Specify the metric you want to monitor
            }
        }

    def training_step(self, batch, batch_idx):
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.BCELoss()(outputs, y)
        
        predicted_masks = (outputs > 0.5).to(torch.uint8) # 0아니면 1로 바꾸는
        y = y.to(torch.uint8)
        
        iou = self.jaccard(predicted_masks, y)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_iou', iou, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.BCELoss()(outputs, y)
        
        predicted_masks = (outputs > 0.5).to(torch.uint8)
        y = y.to(torch.uint8)
  
        iou = self.jaccard(predicted_masks, y)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_iou', iou, prog_bar=True)

        return loss


class R2U_Net(pl.LightningModule):
    def __init__(self,img_ch=3,output_ch=1,t=2):
        super(R2U_Net,self).__init__()
        self.jaccard = BinaryJaccardIndex()
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)

        self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)

        self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
        
        self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
        
        self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
        
        self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
        

        self.Up5 = up_conv(ch_in=1024,ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
        
        self.Up4 = up_conv(ch_in=512,ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
        
        self.Up3 = up_conv(ch_in=256,ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
        
        self.Up2 = up_conv(ch_in=128,ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)

        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
        self.sigmoid = nn.Sigmoid()


    def forward(self,x):
        # encoding path
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)
        
        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4,d5),dim=1)
        d5 = self.Up_RRCNN5(d5)
        
        d4 = self.Up4(d5)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_RRCNN2(d2)

        d1 = self.Conv_1x1(d2)
        d1 = self.sigmoid(d1)

        return d1

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',	# Loss최소화,최대화 결정
                                                               factor=1e-8,	# 학습률 감소율
                                                               patience=5,
                                                               verbose=True,)
        monitor_metric = 'val_loss'
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': monitor_metric  # Specify the metric you want to monitor
            }
        }

    def training_step(self, batch, batch_idx):
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.BCELoss()(outputs, y)
        
        predicted_masks = (outputs > 0.5).to(torch.uint8) # 0아니면 1로 바꾸는
        y = y.to(torch.uint8)
        
        iou = self.jaccard(predicted_masks, y)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_iou', iou, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.BCELoss()(outputs, y)
        
        predicted_masks = (outputs > 0.5).to(torch.uint8)
        y = y.to(torch.uint8)
  
        iou = self.jaccard(predicted_masks, y)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_iou', iou, prog_bar=True)

        return loss



class AttU_Net(pl.LightningModule):
    def __init__(self,img_ch=3,output_ch=1):
        super(AttU_Net,self).__init__()
        self.jaccard = BinaryJaccardIndex()
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)

        self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
        self.Conv2 = conv_block(ch_in=64,ch_out=128)
        self.Conv3 = conv_block(ch_in=128,ch_out=256)
        self.Conv4 = conv_block(ch_in=256,ch_out=512)
        self.Conv5 = conv_block(ch_in=512,ch_out=1024)

        self.Up5 = up_conv(ch_in=1024,ch_out=512)
        self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512,ch_out=256)
        self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        
        self.Up3 = up_conv(ch_in=256,ch_out=128)
        self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        
        self.Up2 = up_conv(ch_in=128,ch_out=64)
        self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
        self.sigmoid = nn.Sigmoid()


    def forward(self,x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5,x=x4)
        d5 = torch.cat((x4,d5),dim=1)        
        d5 = self.Up_conv5(d5)
        
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4,x=x3)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3,x=x2)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2,x=x1)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        d1 = self.sigmoid(d1)

        return d1

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',	# Loss최소화,최대화 결정
                                                               factor=1e-8,	# 학습률 감소율
                                                               patience=5,
                                                               verbose=True,)
        monitor_metric = 'val_loss'
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': monitor_metric  # Specify the metric you want to monitor
            }
        }

    def training_step(self, batch, batch_idx):
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.BCELoss()(outputs, y)
        
        predicted_masks = (outputs > 0.5).to(torch.uint8) # 0아니면 1로 바꾸는
        y = y.to(torch.uint8)
        
        iou = self.jaccard(predicted_masks, y)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_iou', iou, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.BCELoss()(outputs, y)
        
        predicted_masks = (outputs > 0.5).to(torch.uint8)
        y = y.to(torch.uint8)
  
        iou = self.jaccard(predicted_masks, y)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_iou', iou, prog_bar=True)

        return loss


class R2AttU_Net(pl.LightningModule):
    def __init__(self,img_ch=3,output_ch=1,t=2):
        super(R2AttU_Net,self).__init__()
        self.jaccard = BinaryJaccardIndex()
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)

        self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)

        self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
        
        self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
        
        self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
        
        self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
        

        self.Up5 = up_conv(ch_in=1024,ch_out=512)
        self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
        
        self.Up4 = up_conv(ch_in=512,ch_out=256)
        self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
        
        self.Up3 = up_conv(ch_in=256,ch_out=128)
        self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
        
        self.Up2 = up_conv(ch_in=128,ch_out=64)
        self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)

        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
        self.sigmoid = nn.Sigmoid()


    def forward(self,x):
        # encoding path
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)
        
        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5,x=x4)
        d5 = torch.cat((x4,d5),dim=1)
        d5 = self.Up_RRCNN5(d5)
        
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4,x=x3)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3,x=x2)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2,x=x1)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_RRCNN2(d2)

        d1 = self.Conv_1x1(d2)
        d1 = self.sigmoid(d1)

        return d1

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',	# Loss최소화,최대화 결정
                                                               factor=1e-8,	# 학습률 감소율
                                                               patience=5,
                                                               verbose=True,)
        monitor_metric = 'val_loss'
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': monitor_metric  # Specify the metric you want to monitor
            }
        }

    def training_step(self, batch, batch_idx):
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.BCELoss()(outputs, y)
        
        predicted_masks = (outputs > 0.5).to(torch.uint8) # 0아니면 1로 바꾸는
        y = y.to(torch.uint8)
        
        iou = self.jaccard(predicted_masks, y)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_iou', iou, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch['X'], batch['y']
        outputs = self(x).squeeze(dim=1)
        loss = nn.BCELoss()(outputs, y)
        
        predicted_masks = (outputs > 0.5).to(torch.uint8)
        y = y.to(torch.uint8)
  
        iou = self.jaccard(predicted_masks, y)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_iou', iou, prog_bar=True)

        return loss

In [6]:
# WandbLogger를 초기화
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

wandb.login(key='eed81e1c0a41dd8dd67a4ca90cea1be5a06d4eb0')
wandb_logger = WandbLogger(project='AISW', entity='hcim', name='R2U_Net')

# U_Net , R2U_Net , AttU_Net , R2AttU_Net
model = R2U_Net()

early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=8,
    verbose=True,
    mode='min'
)

trainer = pl.Trainer(devices=[0],max_epochs=100, logger=wandb_logger, callbacks=[early_stopping_callback])
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchan4im[0m ([33mhcim[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

   | Name      | Type               | Params
--------------------------------------------------
0  | jaccard   | BinaryJaccardIndex | 0     
1  | Maxpool   | MaxPool2d          | 0     
2  | Upsample  | Upsample           | 0     
3  | RRCNN1    | RRCNN_block        | 74.4 K
4  | RRCNN2    | RRCNN_block        | 304 K 
5  | RRCNN3    | RRCNN_block        | 1.2 M 
6  | RRCNN4    | RRCNN_block        | 4.9 M 
7  | RRCNN5    | RRCNN_block        | 19.4 M
8  | Up5       | up_conv            | 4.7 M 
9  | Up_RRCNN5 | RRCNN_block        | 5.2 M 
10 | Up4       | up_conv            | 1.2 M 
11 | Up_RRCNN4 | RRCNN_block        | 1.3 M 
12 | Up3       | up_conv            | 295 K 
13 | Up_RRCNN3 | RRCNN_block        | 328 K 
14 | Up2       | up_conv            | 73.9 K
15 | Up_RRCNN2 | RRCNN_block        | 82.4 K
16 | Conv_1x1  | Conv2d             | 65    
17 | sigmoid   | Sigmoid            | 0     
-------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.146


Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [7]:
"""
class segDataset(Dataset):
    def __init__(self, data_path, transforms=None):
        # cata path 설정 잘하기
        self.pos_imgs = sorted(glob(data_path + 'Positive/Image/*'))
        self.pos_labels = sorted(glob(data_path + 'Positive/Label/*'))
        self.neg_imgs = sorted(glob(data_path + 'Negative/Image/*'))
        self.neg_labels = sorted(glob(data_path + 'Negative/Label /*'))
        # positive와 negative를 합쳐서 불러오는 코드를 작성
        self.imgs = self.pos_imgs + self.neg_imgs
        self.labels = self.pos_labels + self.neg_labels
        self.transforms = transforms
        
    def __len__(self):
        return len(self.labels)
        
    def __getitem__(self, item):
        img_path = self.imgs[item]
        label_path = self.labels[item]
        img = cv2.imread(img_path)
        label = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        label = np.expand_dims(label, axis=2)

        concat = np.concatenate([img, label], axis=2)
        concat = torch.from_numpy(concat)
        concat = concat.permute(2,0,1) # (h,w,c -> c,h,w)

        
        if self.transforms:
            imgs = self.transforms(concat)

        X = imgs[:3].to(torch.float32)
        y = imgs[3].to(torch.float32)
            
        return {'X' : X/255, 'y': y/255}

"""

"\nclass segDataset(Dataset):\n    def __init__(self, data_path, transforms=None):\n        # cata path 설정 잘하기\n        self.pos_imgs = sorted(glob(data_path + 'Positive/Image/*'))\n        self.pos_labels = sorted(glob(data_path + 'Positive/Label/*'))\n        self.neg_imgs = sorted(glob(data_path + 'Negative/Image/*'))\n        self.neg_labels = sorted(glob(data_path + 'Negative/Label /*'))\n        # positive와 negative를 합쳐서 불러오는 코드를 작성\n        self.imgs = self.pos_imgs + self.neg_imgs\n        self.labels = self.pos_labels + self.neg_labels\n        self.transforms = transforms\n        \n    def __len__(self):\n        return len(self.labels)\n        \n    def __getitem__(self, item):\n        img_path = self.imgs[item]\n        label_path = self.labels[item]\n        img = cv2.imread(img_path)\n        label = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)\n        label = np.expand_dims(label, axis=2)\n\n        concat = np.concatenate([img, label], axis=2)\n        concat = torc

In [8]:
import random
import numpy as np
data_path = './data/Train/'

def set_seed(seed):
    random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
set_seed(7)

In [9]:
def tensor_to_image(tensor):
    tensor = tensor.squeeze()
    numpy_image = tensor.cpu().numpy()
    numpy_image = (numpy_image * 255).astype(np.uint8)
    
    if numpy_image.shape[0] == 1: # 흑백 이미지라면 채널 차원을 제거.
        numpy_image = np.squeeze(numpy_image, axis=0)
        
    elif numpy_image.shape[0] == 3: # RGB 이미지라면 채널 차원을 맨 뒤로 이동.
        numpy_image = np.transpose(numpy_image, (1, 2, 0))
    return numpy_image

In [10]:
import torch
from PIL import Image

model.eval()
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # (h,w,c -> c,h,w) + Normalize
])
def inference_and_save_images(input_dir, output_dir, start_index, end_index):
    for i in range(start_index, end_index + 1):
        input_path = os.path.join(input_dir, f"{i}.jpg")
        output_path = os.path.join(output_dir, f"{i}.jpg")

        # 파일이 존재하지 않으면 스킵. Test/Positive/8322.jpg가 없음
        if not os.path.exists(input_path):
            print(f"File {input_path} not found. Skipping...")
            continue
            
        # 이미지 불러오기 및 전처리
        img = Image.open(input_path)
        img = transform(img).unsqueeze(0)

        # 추론
        with torch.no_grad():
            output = model(img)

        # 텐서를 이미지로 변환후 저장
        output_image = tensor_to_image(output)
        cv2.imwrite(output_path, output_image)

        print(f"Inferred and saved {input_path} to {output_path}")

# Positive 예측
positive_input_dir = "./data/Test/Positive/"
positive_output_dir = "./data/Prediction/Positive/"
inference_and_save_images(positive_input_dir, positive_output_dir, 8000, 8817)

# Negative 예측
negative_input_dir = "./data/Test/Negative/"
negative_output_dir = "./data/Prediction/Negative/"
inference_and_save_images(negative_input_dir, negative_output_dir, 1001, 1181)


KeyboardInterrupt



In [None]:
!zip -r Prediction.zip ./data/Prediction

In [None]:
"""
for i in range(x.shape[0]):
    x = x.permute(0, 2, 3, 1)
    cv2.imwrite(f"{batch_idx}_img.jpg", x[i].cpu().numpy()*255)
    cv2.imwrite(f"{batch_idx}_label.jpg", y[i].cpu().numpy()*255)
    cv2.imwrite(f"{batch_idx}_output.jpg", outputs[i].cpu().detach().numpy()*255) 
    """