In [3]:

import os
import random
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
from tqdm import tqdm
# Compute device: the CUDA_DEVICE env var overrides the GPU choice
# (default "cuda:5", the previously hard-coded value); CPU when no CUDA.
device = torch.device(
    os.environ.get("CUDA_DEVICE", "cuda:5") if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook relies on (DataLoader shuffling, weight init)
# so Restart & Run All reproduces the same run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# numpy HxWxC array -> torch CxHxW tensor (torchvision ToTensor).
tf = T.ToTensor()


In [None]:
# Training hyperparameters. lr/beta1/beta2 configure the Adam optimizer and
# batch_size the DataLoaders.
params = dict(
    image_size=512,
    lr=2e-4,
    beta1=0.5,
    beta2=0.999,
    batch_size=16,
    epochs=500,
)

In [None]:
# Load one cross-validation fold of images and masks.
# Change FOLD (0-4) to load cv1..cv4 instead.
DATA_DIR = '../../data'
FOLD = 0
image1 = np.load(f'{DATA_DIR}/cv{FOLD}_ori.npy')
mask1 = np.load(f'{DATA_DIR}/cv{FOLD}_mask.npy')

# CustomDataset pairs images and masks by index, so the counts must agree.
assert len(image1) == len(mask1), (
    f"image/mask count mismatch: {len(image1)} vs {len(mask1)}")

In [None]:
class CustomDataset(Dataset):
    """Paired image/mask segmentation dataset backed by in-memory arrays.

    ``__getitem__`` converts the grayscale image to 3-channel RGB
    (``cv2.COLOR_GRAY2RGB``), resizes the mask with nearest-neighbour
    interpolation, and passes both through the module-level ``tf`` transform.

    Args:
        image_list: indexable of grayscale images; presumably 2-D arrays such
            as the slices of an (N, H, W) ``.npy`` stack — TODO confirm.
        label_list: indexable of masks, aligned by index with ``image_list``.
        mask_size: optional ``(width, height)`` for the resized mask (cv2
            argument order). ``None`` (default) resizes each mask to its
            image's own spatial size, matching models that predict at input
            resolution. Pass ``(128, 128)`` for the former fixed size.

    Raises:
        ValueError: if ``image_list`` and ``label_list`` differ in length.
    """

    def __init__(self, image_list, label_list, mask_size=None):
        if len(image_list) != len(label_list):
            raise ValueError(
                f"image_list and label_list differ in length: "
                f"{len(image_list)} vs {len(label_list)}")
        # Attribute names kept for backward compatibility; they hold arrays,
        # not file paths.
        self.img_path = image_list
        self.label = label_list
        self.mask_size = mask_size

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor)`` for sample ``idx``."""
        image = self.img_path[idx]
        mask = self.label[idx]

        if self.mask_size is None:
            height, width = image.shape[:2]
            size = (width, height)  # cv2 expects (width, height)
        else:
            size = self.mask_size

        image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        # Nearest-neighbour keeps mask values discrete; the default bilinear
        # interpolation would blend neighbouring label values at boundaries.
        # NOTE(review): if masks are uint8, ToTensor also rescales them by
        # 1/255 — confirm the mask dtype/value range.
        mask_resized = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

        return tf(image_rgb), tf(mask_resized)

# NOTE(review): validation is built from the same fold as training
# (image1/mask1), so its scores measure fit, not generalisation.
# TODO: point val_dataset at a held-out fold once one is loaded.
train_dataset = CustomDataset(image1, mask1)
val_dataset = CustomDataset(image1, mask1)

train_dataloader = DataLoader(
    train_dataset, batch_size=params['batch_size'], shuffle=True, drop_last=True)
# Evaluation should see every sample in a fixed order: no shuffling and no
# dropped tail batch.
validation_dataloader = DataLoader(
    val_dataset, batch_size=params['batch_size'], shuffle=False, drop_last=False)

In [None]:
def dice_loss(pred, target, num_classes=4, smooth=1.):
    """Mean soft Dice loss over the first ``num_classes`` channels.

    For each class channel ``c`` the loss is
    ``1 - (2*sum(p*t) + smooth) / (sum(p*p) + sum(t*t) + smooth)``, with the
    sums taken over the batch and all spatial positions. The result is the
    mean of the per-class losses.

    Args:
        pred: tensor of shape (batch, channels, *spatial). It should hold
            bounded scores such as probabilities (after sigmoid/softmax);
            raw logits make the ratio meaningless.
        target: tensor broadcast-compatible with ``pred``, with at least
            ``num_classes`` channels.
        num_classes: number of leading channels to score; extra channels
            are ignored.
        smooth: additive smoothing; keeps empty classes at loss 0 and
            avoids 0/0.

    Returns:
        0-d tensor, differentiable w.r.t. ``pred``.

    Raises:
        ValueError: if ``pred`` or ``target`` has fewer than ``num_classes``
            channels.
    """
    if pred.shape[1] < num_classes or target.shape[1] < num_classes:
        raise ValueError(
            f"expected at least {num_classes} channels, got pred={pred.shape[1]} "
            f"target={target.shape[1]}")

    pred = pred[:, :num_classes]
    target = target[:, :num_classes]

    # Reduce over batch (dim 0) and every spatial dim, keeping the class dim.
    reduce_dims = (0,) + tuple(range(2, pred.dim()))
    intersection = (pred * target).sum(dim=reduce_dims)
    pred_sq = (pred * pred).sum(dim=reduce_dims)
    target_sq = (target * target).sum(dim=reduce_dims)

    dice_per_class = 1 - (2. * intersection + smooth) / (pred_sq + target_sq + smooth)
    return dice_per_class.mean()

# MAnet segmentation network with a ResNet-34 encoder initialised from
# ImageNet weights. Other encoders (e.g. mobilenet_v2, efficientnet-b7) can
# be swapped in via encoder_name.
# in_channels=3: the dataset replicates grayscale images to RGB.
# classes=4: one output channel per segmentation class.
model = smp.MAnet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=4,
).to(device)

# Adam over the trainable parameters only, with hyperparameters from `params`.
optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=params['lr'],
    betas=(params['beta1'], params['beta2']),
)


In [4]:
# Layer-by-layer summary for a single RGB image at the configured resolution.
summary(model, input_size=(1, 3, params['image_size'], params['image_size']))


Layer (type:depth-idx)                                                                                    Output Shape              Param #
Mask2FormerForUniversalSegmentation                                                                       [1, 100, 256]             --
├─Mask2FormerModel: 1-1                                                                                   [1, 100, 128, 128]        --
│    └─Mask2FormerPixelLevelModule: 2-1                                                                   [1, 256, 16, 16]          --
│    │    └─SwinBackbone: 3-1                                                                             [1, 96, 128, 128]         48,838,602
│    │    └─Mask2FormerPixelDecoder: 3-2                                                                  [1, 256, 128, 128]        5,421,504
│    └─Mask2FormerTransformerModule: 2-2                                                                  [100, 1, 256]             51,968
│    │    └─Mask2FormerSinePosi

In [None]:

# Train for params['epochs'] epochs with a no-grad validation pass after each.
# The loss is the soft Dice loss from dice_loss(); the reported score is
# 1 - loss. Per-epoch means are appended to the *_list variables.
train_loss_list = []
val_loss_list = []
train_acc_list = []
val_acc_list = []
metrics = defaultdict(float)
num_epochs = params['epochs']  # from the config cell (was a hard-coded 300)
for epoch in range(num_epochs):
    model.train()  # once per epoch; the validation pass switches to eval()
    progress = tqdm(train_dataloader)
    count = 0
    running_loss = 0.0
    acc_loss = 0.0
    for x, y in progress:
        x = x.to(device).float()
        # NOTE(review): dice_loss scores 4 class channels, so masks are
        # presumably one-hot with 4 channels — TODO confirm.
        y = y.to(device).float()
        count += 1
        optimizer.zero_grad()
        # smp models return the prediction tensor itself (no `.logits`
        # attribute). Dice needs bounded scores, so squash the raw outputs;
        # sigmoid scores each class channel independently — use
        # softmax(dim=1) instead if the classes are mutually exclusive.
        predict = torch.sigmoid(model(x))
        cost = dice_loss(predict, y)
        cost.backward()
        optimizer.step()
        batch_loss = cost.item()
        running_loss += batch_loss
        # Accumulate a plain float: summing graph-attached tensors would keep
        # every batch's autograd graph alive until the end of the epoch.
        acc_loss += 1 - batch_loss
        progress.set_description(
            f"epoch: {epoch+1}/{num_epochs} Step: {count} dice_loss : {running_loss/count:.4f} dice_score: {1-running_loss/count:.4f}")
    train_loss_list.append(running_loss / count if count else float('nan'))
    train_acc_list.append(acc_loss / count if count else float('nan'))

    # Validation: eval mode and no gradient tracking.
    model.eval()
    val_count = 0
    val_running_loss = 0.0
    with torch.no_grad():
        for val_x, val_y in validation_dataloader:
            val_x = val_x.to(device).float()
            val_y = val_y.to(device).float()
            val_running_loss += dice_loss(torch.sigmoid(model(val_x)), val_y).item()
            val_count += 1
    val_loss = val_running_loss / val_count if val_count else float('nan')
    val_loss_list.append(val_loss)
    val_acc_list.append(1 - val_loss)

In [None]:
# Sanity-check the output shape on the last training batch `x` left over
# from the training loop. Move it to the model's device (it may be on CPU)
# and skip gradient tracking.
model.eval()
with torch.no_grad():
    sample_output = model(x.to(device).float())
sample_output.shape