In [1]:
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset


class CoronarySmallDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])
        

        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        image = cv2.resize(image, (256, 256))
        
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGBA2RGB)
        mask = cv2.resize(mask, (256, 256))
        
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        
        return image, mask

In [2]:
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms


transform = transforms.Compose([
    transforms.ToTensor()
])

train_image_dir = 'images_train\input'
train_mask_dir = 'images_train\output'
val_image_dir = 'images_val\input'
val_mask_dir = 'images_val\output'
test_image_dir = 'images_test\input'
test_mask_dir = 'images_test\output'

train_dataset = CoronarySmallDataset(train_image_dir, train_mask_dir, transform=transform)
val_dataset = CoronarySmallDataset(val_image_dir, val_mask_dir, transform=transform)
test_dataset = CoronarySmallDataset(test_image_dir, test_mask_dir, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


In [3]:
from large_RGB_model import UNet


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet()
model = model.to(device)

In [4]:
import torch.optim as optim
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F


criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50):
    best_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, masks in tqdm(train_loader):
            images = images.to(device)
            masks = masks.to(device)

            # print(images.size())
            # print(images)
            # print(masks.size())
            # print(masks)

            optimizer.zero_grad()
            outputs = model(images)

            # print(outputs.size())
            # print(outputs)
            
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * images.size(0)
        
        train_loss = train_loss / len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)
                loss = criterion(outputs, masks)
                
                val_loss += loss.item() * images.size(0)
        
        val_loss = val_loss / len(val_loader.dataset)
        
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'large_RGB_model.pth')


In [5]:
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=55)

100%|██████████| 70/70 [04:50<00:00,  4.15s/it]


Epoch 1/55, Train Loss: 0.6332, Val Loss: 0.5480


100%|██████████| 70/70 [05:06<00:00,  4.38s/it]


Epoch 2/55, Train Loss: 0.5094, Val Loss: 0.5121


100%|██████████| 70/70 [05:09<00:00,  4.42s/it]


Epoch 3/55, Train Loss: 0.4648, Val Loss: 0.4418


100%|██████████| 70/70 [05:11<00:00,  4.45s/it]


Epoch 4/55, Train Loss: 0.4220, Val Loss: 0.4160


100%|██████████| 70/70 [05:06<00:00,  4.38s/it]


Epoch 5/55, Train Loss: 0.3811, Val Loss: 0.3605


100%|██████████| 70/70 [05:02<00:00,  4.33s/it]


Epoch 6/55, Train Loss: 0.3472, Val Loss: 0.3395


100%|██████████| 70/70 [05:02<00:00,  4.32s/it]


Epoch 7/55, Train Loss: 0.3148, Val Loss: 0.3036


100%|██████████| 70/70 [05:00<00:00,  4.30s/it]


Epoch 8/55, Train Loss: 0.2842, Val Loss: 0.2789


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 9/55, Train Loss: 0.2568, Val Loss: 0.2452


100%|██████████| 70/70 [05:02<00:00,  4.32s/it]


Epoch 10/55, Train Loss: 0.2323, Val Loss: 0.2238


100%|██████████| 70/70 [05:00<00:00,  4.30s/it]


Epoch 11/55, Train Loss: 0.2121, Val Loss: 0.2043


100%|██████████| 70/70 [05:01<00:00,  4.30s/it]


Epoch 12/55, Train Loss: 0.1943, Val Loss: 0.1922


100%|██████████| 70/70 [04:58<00:00,  4.26s/it]


Epoch 13/55, Train Loss: 0.1792, Val Loss: 0.1784


100%|██████████| 70/70 [04:57<00:00,  4.26s/it]


Epoch 14/55, Train Loss: 0.1654, Val Loss: 0.1590


100%|██████████| 70/70 [04:58<00:00,  4.26s/it]


Epoch 15/55, Train Loss: 0.1531, Val Loss: 0.1512


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 16/55, Train Loss: 0.1429, Val Loss: 0.1433


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 17/55, Train Loss: 0.1329, Val Loss: 0.1344


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 18/55, Train Loss: 0.1245, Val Loss: 0.1216


100%|██████████| 70/70 [05:03<00:00,  4.33s/it]


Epoch 19/55, Train Loss: 0.1171, Val Loss: 0.1153


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 20/55, Train Loss: 0.1111, Val Loss: 0.1095


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 21/55, Train Loss: 0.1049, Val Loss: 0.1042


100%|██████████| 70/70 [05:02<00:00,  4.32s/it]


Epoch 22/55, Train Loss: 0.1001, Val Loss: 0.0999


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 23/55, Train Loss: 0.0954, Val Loss: 0.0977


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 24/55, Train Loss: 0.0915, Val Loss: 0.0909


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 25/55, Train Loss: 0.0875, Val Loss: 0.0984


100%|██████████| 70/70 [04:57<00:00,  4.24s/it]


Epoch 26/55, Train Loss: 0.0841, Val Loss: 0.0840


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 27/55, Train Loss: 0.0811, Val Loss: 0.0851


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 28/55, Train Loss: 0.0783, Val Loss: 0.0813


100%|██████████| 70/70 [04:59<00:00,  4.28s/it]


Epoch 29/55, Train Loss: 0.0757, Val Loss: 0.0761


100%|██████████| 70/70 [05:02<00:00,  4.33s/it]


Epoch 30/55, Train Loss: 0.0735, Val Loss: 0.0792


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 31/55, Train Loss: 0.0712, Val Loss: 0.0742


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 32/55, Train Loss: 0.0688, Val Loss: 0.0711


100%|██████████| 70/70 [04:59<00:00,  4.28s/it]


Epoch 33/55, Train Loss: 0.0670, Val Loss: 0.0721


100%|██████████| 70/70 [04:59<00:00,  4.28s/it]


Epoch 34/55, Train Loss: 0.0653, Val Loss: 0.0683


100%|██████████| 70/70 [04:58<00:00,  4.27s/it]


Epoch 35/55, Train Loss: 0.0632, Val Loss: 0.0697


100%|██████████| 70/70 [04:56<00:00,  4.24s/it]


Epoch 36/55, Train Loss: 0.0620, Val Loss: 0.0669


100%|██████████| 70/70 [04:56<00:00,  4.24s/it]


Epoch 37/55, Train Loss: 0.0601, Val Loss: 0.0659


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 38/55, Train Loss: 0.0584, Val Loss: 0.0652


100%|██████████| 70/70 [04:59<00:00,  4.28s/it]


Epoch 39/55, Train Loss: 0.0574, Val Loss: 0.0636


100%|██████████| 70/70 [04:59<00:00,  4.27s/it]


Epoch 40/55, Train Loss: 0.0559, Val Loss: 0.0632


100%|██████████| 70/70 [04:56<00:00,  4.24s/it]


Epoch 41/55, Train Loss: 0.0543, Val Loss: 0.0622


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 42/55, Train Loss: 0.0527, Val Loss: 0.0618


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 43/55, Train Loss: 0.0514, Val Loss: 0.0618


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 44/55, Train Loss: 0.0497, Val Loss: 0.0598


100%|██████████| 70/70 [05:00<00:00,  4.30s/it]


Epoch 45/55, Train Loss: 0.0484, Val Loss: 0.0600


100%|██████████| 70/70 [05:01<00:00,  4.31s/it]


Epoch 46/55, Train Loss: 0.0471, Val Loss: 0.0585


100%|██████████| 70/70 [04:58<00:00,  4.27s/it]


Epoch 47/55, Train Loss: 0.0462, Val Loss: 0.0601


100%|██████████| 70/70 [05:00<00:00,  4.29s/it]


Epoch 48/55, Train Loss: 0.0450, Val Loss: 0.0590


100%|██████████| 70/70 [04:57<00:00,  4.24s/it]


Epoch 49/55, Train Loss: 0.0436, Val Loss: 0.0580


100%|██████████| 70/70 [04:59<00:00,  4.28s/it]


Epoch 50/55, Train Loss: 0.0424, Val Loss: 0.0580


100%|██████████| 70/70 [04:59<00:00,  4.28s/it]


Epoch 51/55, Train Loss: 0.0414, Val Loss: 0.0580


100%|██████████| 70/70 [04:59<00:00,  4.28s/it]


Epoch 52/55, Train Loss: 0.0401, Val Loss: 0.0574


100%|██████████| 70/70 [05:02<00:00,  4.33s/it]


Epoch 53/55, Train Loss: 0.0392, Val Loss: 0.0563


100%|██████████| 70/70 [05:06<00:00,  4.38s/it]


Epoch 54/55, Train Loss: 0.0380, Val Loss: 0.0563


100%|██████████| 70/70 [05:06<00:00,  4.37s/it]


Epoch 55/55, Train Loss: 0.0374, Val Loss: 0.0564


In [5]:
import os
import numpy as np
import cv2
import torch
from torchvision import transforms


model.load_state_dict(torch.load('large_RGB_model.pth'))
model.eval()

def show_image(type, image_name):
    dir = f"images_test\{type}"
    print(dir)
    img = cv2.imread(os.path.join(dir, image_name), cv2.IMREAD_UNCHANGED)
    img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    img = cv2.resize(img, (512, 512))
    cv2.imshow(type, img)

def predict(image_name):
    dir = 'images_test\input'
    img = cv2.imread(os.path.join(dir, image_name), cv2.IMREAD_UNCHANGED)
    img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    img = cv2.resize(img, (256, 256))
    img = transforms.ToTensor()(img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(img)
        pred = pred.squeeze().cpu().numpy()
        print(pred)
        # pred = (pred > 0.5).astype(np.uint8)
    R, G, B = pred[0], pred[1], pred[2]
    # cv2.imshow('R', R)
    # cv2.imshow('G', G)
    # cv2.imshow('B', B)
    pred_image = np.zeros((256, 256, 3), dtype=np.uint8)
    pred_image[..., 0] = (R * 255).astype(np.uint8)
    pred_image[..., 1] = (G * 255).astype(np.uint8)
    pred_image[..., 2] = (B * 255).astype(np.uint8)
    # print(pred_image)
    # cv2.imshow('pred0', pred_image)
    bigger_image = cv2.resize(pred_image, (512, 512))
    cv2.imshow('pred', bigger_image)
    cv2.waitKey(0)
   

In [8]:
# image_name = "131aedfhs6pnf1fvtvp49mhdb2fucqzo22_29.png"
image_name = "131aedfhs6pnf1fvtvp49mld7mqexnz322_36.png"
# image_name = "131aedfhs6pnf1fvtvp49mia892s56cf22_28.png"
# image_name = "131aedfhs6pnf1fvtvp49juwu7plj9dv22_40.png"
show_image("input", image_name)
show_image("output", image_name)
predict(image_name)

images_test\input
images_test\output
[[[0.00787001 0.00681553 0.00667353 ... 0.00639666 0.01160035 0.01274195]
  [0.0039624  0.00434883 0.00490309 ... 0.00631991 0.00654971 0.00713173]
  [0.00309924 0.00680399 0.00406838 ... 0.0054316  0.00615383 0.00814165]
  ...
  [0.00676593 0.00556611 0.00742547 ... 0.01003035 0.0077699  0.00923594]
  [0.00460479 0.00467522 0.0073868  ... 0.01161776 0.01015309 0.01873149]
  [0.00919656 0.00782342 0.00772655 ... 0.00801659 0.00969227 0.01574487]]

 [[0.00628947 0.00549719 0.00526443 ... 0.00481923 0.00664366 0.00915967]
  [0.0024287  0.00320141 0.00443484 ... 0.00417889 0.00429756 0.00535465]
  [0.00246641 0.0045667  0.00393924 ... 0.00353131 0.00331387 0.00578081]
  ...
  [0.00540575 0.00494929 0.00539099 ... 0.0070201  0.00636626 0.00596943]
  [0.00279153 0.0031672  0.00443246 ... 0.00558795 0.00601047 0.00926846]
  [0.00584357 0.00526194 0.00663068 ... 0.0060198  0.00765639 0.01367984]]

 [[0.0063608  0.00603837 0.00610488 ... 0.00726714 0.009912

In [None]:
image_name = "131aedfhs6pnf1fvtvp49mia892s56cf22_28.png"
# image_name = "131aedfhs6pnf1fvtvp49juwu7plj9dv22_40.png"
show_image("input", image_name)
show_image("output", image_name)
predict(image_name)

In [None]:
image_name = "131aedfhs6pnf1fvtvp49mia892s56cf22_28.png"
# image_name = "131aedfhs6pnf1fvtvp49juwu7plj9dv22_40.png"
show_image("input", image_name)
show_image("output", image_name)
predict(image_name)

In [11]:
model.load_state_dict(torch.load('large_RGB_model.pth'))

def dice_coefficient(y_true, y_pred):
    smooth = 1.0
    y_true_f = y_true.view(-1)
    y_pred_f = y_pred.view(-1)
    intersection = (y_true_f * y_pred_f).sum()
    return (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)

model.eval()
test_loss = 0.0
dice_scores = []

with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)
        
        test_loss += loss.item() * images.size(0)
        dice_scores.append(dice_coefficient(masks, outputs))

test_loss = test_loss / len(test_loader.dataset)
mean_dice = torch.mean(torch.tensor(dice_scores))

print(f'Test Loss: {test_loss:.4f}')
print(f'Mean Dice Coefficient: {mean_dice:.4f}')


Test Loss: 0.5773
Mean Dice Coefficient: 0.6904


In [None]:
from skimage import morphology

def post_process(prediction):
    prediction_np = prediction.cpu().numpy()
    processed = morphology.remove_small_objects(prediction_np > 0.5, min_size=100)
    processed = morphology.remove_small_holes(processed, area_threshold=100)
    return torch.tensor(processed, device=device)

post_processed_predictions = [post_process(pred) for pred in outputs]


In [None]:
from flask import Flask, request, jsonify
import io

app = Flask(__name__)
model.load_state_dict(torch.load('model.pth'))
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    file = request.files['image']
    img_bytes = file.read()
    img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    img = cv2.resize(img, (512, 512))
    img = transforms.ToTensor()(img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(img)
        pred = pred.squeeze().cpu().numpy()
        pred = (pred > 0.5).astype(np.uint8)
    
    _, buffer = cv2.imencode('.png', pred)
    response = io.BytesIO(buffer)
    response.seek(0)
    return response, 200, {'Content-Type': 'image/png'}

if __name__ == '__main__':
    app.run()
