In [1]:
import torch
import torchvision.transforms.v2 as transforms
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as segm
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
from dataset import PolypDataset
from torchgeometry.losses import DiceLoss

In [2]:
RANDOM_SEED = 34
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
transform = transforms.Compose([
        transforms.RandomResizedCrop((256,256),[0.8,1]),
        transforms.RandomApply([transforms.RandomChoice([transforms.RandomRotation(degrees=(0, 180)),transforms.RandomPerspective(distortion_scale=0.6,p=1)])],p=0.5),
        transforms.RandomVerticalFlip(0.5),
        transforms.RandomHorizontalFlip(0.5),
])
# transform = None
random_augments = transforms.RandomApply([transforms.RandomChoice([transforms.RandomPosterize(bits=3),transforms.RandomAdjustSharpness(sharpness_factor=2),transforms.RandomEqualize(),transforms.GaussianBlur(3),]),
                                          transforms.RandomAutocontrast(0.5),
                                          transforms.RandomGrayscale(0.1)],p=0.5)

test_transforms = transforms.Compose([transforms.Resize((256,256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), #ImageNet mean and std
])



In [4]:
dataset = PolypDataset(image_dir='../data', transform=transform, random_transforms = random_augments,test_transform=test_transforms)
train_dataset,valid_dataset = dataset.split_train_test(test_size=0.2,random_seed=RANDOM_SEED)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
valid_dataloader = DataLoader(valid_dataset, batch_size=4, shuffle=True, num_workers=4)

In [5]:
print(len(train_dataset),len(valid_dataset))

800 200


In [6]:
# Ensure model initialization the same
torch.manual_seed(34)

<torch._C.Generator at 0x1e47eead810>

In [8]:
net = segm.UnetPlusPlus(encoder_name = "resnet50",encoder_depth = 5, encoder_weights = 'imagenet', in_channels=3, classes=3, decoder_channels=(256,128,64,32,16),decoder_use_batchnorm=True)

In [None]:
adam = torch.optim.Adam(net.parameters(), lr=0.0001)
weights = torch.Tensor([[0.4, 0.55, 0.05]]).to(device)
criterion_1 = DiceLoss().to(device)
criterion_2 = torch.nn.CrossEntropyLoss(weight = weights).to(device)
num_epochs = 50
net.to(device)

In [None]:
train_losses = []
val_losses = []
best_val_loss = 99999

In [None]:
import wandb
wandb.login(
    key = "b837839166bd4f97a07e90a26fa965ee17f8b64f"
)
wandb.init(
    project = "PolypSegmentTest"
)

In [None]:
for epoch in range(num_epochs):
    net.train()
    train_loss = 0
    for images, masks in tqdm(train_dataloader):
        images = images.to(device)
        masks = masks.to(device)
        adam.zero_grad()
        outputs = net(images)
        loss_1 = criterion_1(outputs, masks)
        loss_2 = criterion_2(outputs, masks)
        loss = (loss_1 + loss_2)/2
        loss.backward()
        adam.step()
        train_loss += loss.item()
    train_losses.append(train_loss/len(train_dataloader))

    net.eval()
    val_loss = 0
    for images, masks in tqdm(valid_dataloader):
        images = images.to(device)
        masks = masks.to(device)
        outputs = net(images)
        loss_1 = criterion_1(outputs, masks)
        loss_2 = criterion_2(outputs, masks)
        loss = (loss_1 + loss_2)/2
        val_loss += loss.item()
    val_losses.append(val_loss/len(valid_dataloader))
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint = { 
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': adam.state_dict(),
            'loss': val_loss,
        }
        save_path = f'model.pth'
        torch.save(checkpoint, save_path)
    wandb.log({'Val_loss': val_losses[-1],'Train_loss': train_losses[-1]})

    print(f"Epoch {epoch+1}/{num_epochs} Train Loss: {train_losses[-1]} Val Loss: {val_losses[-1]}")

In [9]:
torch.save(checkpoint["model"],"weight.pth")

In [10]:
checkpoint = torch.load(f"weight.pth")
net.load_state_dict(checkpoint)
net.eval()
net.to(device)

UnetPlusPlus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential

In [None]:
import cv2
import os

color_dict= {2: (0, 0, 0),
             0: (255, 0, 0),
             1: (0, 255, 0)}

def mask_to_rgb(mask, color_dict):
    output = np.zeros((mask.shape[0], mask.shape[1], 3))

    for k in color_dict.keys():
        output[mask==k] = color_dict[k]

    return np.uint8(output) 

for i in tqdm(os.listdir("../data/test/test")):
    img_path = os.path.join("../data/test/test", i)
    ori_img = cv2.imread(img_path)
    ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGSB)
    ori_w = ori_img.shape[0]
    ori_h = ori_img.shape[1]
    img = cv2.resize(ori_img, (256, 256))
    img = test_transforms(img)
    input_img = img.unsqueeze(0).to(device)
    output_mask = net(input_img).squeeze(0).cpu().detach().numpy().transpose(1,2,0)
    mask = cv2.resize(output_mask, (ori_h, ori_w))
    mask = np.argmax(mask, axis=2)
    mask_rgb = mask_to_rgb(mask, color_dict)
    mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite("./prediction/{}".format(i), mask_rgb) 

  0%|          | 0/200 [00:00<?, ?it/s]

In [11]:
def rle_to_string(runs):
    return ' '.join(str(x) for x in runs)

def rle_encode_one_mask(mask):
    pixels = mask.flatten()
    pixels[pixels > 225] = 255
    pixels[pixels <= 225] = 0
    use_padding = False
    if pixels[0] or pixels[-1]:
        use_padding = True
        pixel_padded = np.zeros([len(pixels) + 2], dtype=pixels.dtype)
        pixel_padded[1:-1] = pixels
        pixels = pixel_padded
    rle = np.where(pixels[1:] != pixels[:-1])[0] + 2
    if use_padding:
        rle = rle - 1
    rle[1::2] = rle[1::2] - rle[:-1:2]
    
    return rle_to_string(rle)

def rle2mask(mask_rle, shape=(3,3)):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T

def mask2string(dir):
    strings = []
    ids = []
    ws, hs = [[] for i in range(2)]
    for image_id in os.listdir(dir):
        id = image_id.split('.')[0]
        path = os.path.join(dir, image_id)
        print(path)
        img = cv2.imread(path)[:,:,::-1]
        h, w = img.shape[0], img.shape[1]
        for channel in range(2):
            ws.append(w)
            hs.append(h)
            ids.append(f'{id}_{channel}')
            string = rle_encode_one_mask(img[:,:,channel])
            strings.append(string)
    r = {
        'ids': ids,
        'strings': strings,
    }
    return r


MASK_DIR_PATH = './prediction'
dir = MASK_DIR_PATH
res = mask2string(dir)
df = pd.DataFrame(columns=['Id', 'Expected'])
df['Id'] = res['ids']
df['Expected'] = res['strings']

df.to_csv(r'output.csv', index=False)

./prediction\019410b1fcf0625f608b4ce97629ab55.jpeg
./prediction\02fa602bb3c7abacdbd7e6afd56ea7bc.jpeg
./prediction\0398846f67b5df7cdf3f33c3ca4d5060.jpeg
./prediction\05734fbeedd0f9da760db74a29abdb04.jpeg
./prediction\05b78a91391adc0bb223c4eaf3372eae.jpeg
./prediction\0619ebebe9e9c9d00a4262b4fe4a5a95.jpeg
./prediction\0626ab4ec3d46e602b296cc5cfd263f1.jpeg
./prediction\0a0317371a966bf4b3466463a3c64db1.jpeg
./prediction\0a5f3601ad4f13ccf1f4b331a412fc44.jpeg
./prediction\0af3feff05dec1eb3a70b145a7d8d3b6.jpeg
./prediction\0fca6a4248a41e8db8b4ed633b456aaa.jpeg
./prediction\1002ec4a1fe748f3085f1ce88cbdf366.jpeg
./prediction\1209db6dcdda5cc8a788edaeb6aa460a.jpeg
./prediction\13dd311a65d2b46d0a6085835c525af6.jpeg
./prediction\1531871f2fd85a04faeeb2b535797395.jpeg
./prediction\15fc656702fa602bb3c7abacdbd7e6af.jpeg
./prediction\1ad4f13ccf1f4b331a412fc44655fb51.jpeg
./prediction\1b62f15ec83b97bb11e8e0c4416c1931.jpeg
./prediction\1c0e9082ea2c193ac8d551c149b60f29.jpeg
./prediction\1db239dda50f954ba5