In [None]:
# Imports
import os
from PIL import Image, UnidentifiedImageError
import torch
import random
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torchvision
import numpy as np
from torch.optim.lr_scheduler import StepLR

from tqdm import tqdm
import torch.optim as optim



In [None]:
# Seed every random number generator this notebook relies on.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn

if torch.cuda.is_available():
    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Request deterministic cuDNN kernels and disable cuDNN benchmarking.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class DoubleConv(nn.Module):
    """Two back-to-back (3x3 convolution -> BatchNorm -> ReLU) stages.

    Both convolutions use kernel 3, stride 1 and padding 1, so height and
    width are unchanged; only the channel count goes from ``in_channels``
    to ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by each convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages.extend(
                (
                    # Convolution bias is switched off; a BatchNorm layer follows directly.
                    nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
            )
            # The second stage consumes what the first one produced.
            stage_in = out_channels
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        out = self.conv(x)
        return out
###
class UNET(nn.Module):
    """U-Net style encoder/decoder assembled from ``DoubleConv`` blocks.

    The encoder applies one ``DoubleConv`` per entry of ``features``, each
    followed by 2x2 max pooling. A bottleneck ``DoubleConv`` doubles the
    deepest width. The decoder mirrors the encoder: a stride-2 transposed
    convolution, concatenation with the matching encoder output, then a
    ``DoubleConv``. A final 1x1 convolution maps to ``out_channels``; no
    activation is applied, so the output is raw logits.

    Args:
        in_channels: channel count of the input image tensor.
        out_channels: channel count of the returned tensor.
        features: channel widths of the encoder levels, shallowest first.
            Any non-empty sequence is accepted.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super(UNET, self).__init__()
        # Immutable default + local copy: the argument is never shared or mutated.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each level takes the previous level's width as its input width.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: ``ups`` alternates [upsample, DoubleConv, upsample, DoubleConv, ...].
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input width is doubled because the skip connection is concatenated.
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the logits produced for input batch ``x``."""
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest encoder output first, to match the decoder order.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Pooling floors odd sizes, so the upsampled map can come out smaller
            # than the skip tensor; resize it to the skip's height/width.
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)
        return self.final_conv(x)


In [None]:
class DATASET(Dataset):
    """Image / mask pairs for binary segmentation, read from two folders.

    Every file listed in ``img_dir`` is one sample. Its mask is looked up in
    ``mask_dir`` under the same file name with ``".jpeg"`` replaced by
    ``".png"``. Image and mask are both resized to 256x256 (bilinear for the
    image, nearest-neighbour for the mask).

    Args:
        img_dir: folder containing the input images.
        mask_dir: folder containing the masks.
        transform: optional torchvision transform for the image. When it is a
            ``transforms.Compose``, its random rotation / flip / affine steps
            are also applied to the mask with the same random draw, so the
            mask stays aligned with the augmented image.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always refers to the same file.
        self.images = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.images)

    def _transform_pair(self, image, mask):
        """Apply ``self.transform`` to ``image`` and mirror its geometry on ``mask``.

        Non-geometric steps (ToTensor, Normalize, ...) touch the image only.
        For a geometric random step the RNG state is rewound after the image
        pass so the mask receives exactly the same rotation / flip.
        """
        geometric_steps = (
            transforms.RandomRotation,
            transforms.RandomHorizontalFlip,
            transforms.RandomVerticalFlip,
            transforms.RandomAffine,
        )
        if isinstance(self.transform, transforms.Compose):
            steps = self.transform.transforms
        else:
            steps = [self.transform]

        for step in steps:
            if isinstance(step, geometric_steps):
                # Rewind both generators: which one torchvision draws from
                # depends on its version.
                torch_state = torch.get_rng_state()
                python_state = random.getstate()
                image = step(image)
                torch.set_rng_state(torch_state)
                random.setstate(python_state)
                mask = step(mask)
            else:
                image = step(image)
        return image, mask

    def __getitem__(self, index):
        file_name = self.images[index]
        img_path = os.path.join(self.img_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name.replace(".jpeg", ".png"))

        image = Image.open(img_path).convert("RGB").resize((256, 256), Image.BILINEAR)
        mask = Image.open(mask_path).convert("L").resize((256, 256), Image.NEAREST)

        if self.transform is not None:
            image, mask = self._transform_pair(image, mask)
        else:
            image = transforms.ToTensor()(image)

        # ToTensor scales 8-bit pixels to [0, 1], so a 255 foreground is already
        # 1.0 here; threshold at 0.5 to guarantee a strictly binary {0, 1} mask.
        mask = transforms.ToTensor()(mask)
        mask = (mask > 0.5).float()
        return image, mask

In [None]:
# Hyperparameters and data locations
LEARNING_RATE = 1e-4
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 5
NUM_EPOCHS = 5
NUM_WORKERS = 4
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
# Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
PIN_MEMORY = DEVICE == "cuda"
# NOTE(review): LOAD_MODEL is not read anywhere in the visible cells - confirm it is still needed.
LOAD_MODEL = True
TRAIN_IMG_DIR = "../data/train"
TRAIN_MASK_DIR = "../data/train_mask"
VAL_IMG_DIR = '../data/test'
VAL_MASK_DIR = '../data/test_mask'
# Sigmoid outputs above this value are counted as foreground.
THRESHOLD = 0.65

In [None]:

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Write ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: object to serialise (whatever the caller wants to persist).
        filename: destination path of the checkpoint file.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Load the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Args:
        checkpoint: mapping that holds the weights under the key ``"state_dict"``.
        model: object exposing ``load_state_dict``; it is updated in place.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


def get_loaders(TRAIN_IMG_DIR, TRAIN_MASK_DIR, VAL_IMG_DIR, VAL_MASK_DIR, BATCH_SIZE, train_transform, val_transform, NUM_WORKERS, PIN_MEMORY):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        TRAIN_IMG_DIR, TRAIN_MASK_DIR: image / mask folders of the training set.
        VAL_IMG_DIR, VAL_MASK_DIR: image / mask folders of the validation set.
        BATCH_SIZE: batch size used by both loaders.
        train_transform: transform handed to the training ``DATASET``.
        val_transform: transform handed to the validation ``DATASET``.
        NUM_WORKERS: number of loader worker processes.
        PIN_MEMORY: forwarded to ``DataLoader(pin_memory=...)``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_dataset = DATASET(img_dir=TRAIN_IMG_DIR, mask_dir=TRAIN_MASK_DIR, transform=train_transform)
    val_dataset = DATASET(img_dir=VAL_IMG_DIR, mask_dir=VAL_MASK_DIR, transform=val_transform)

    loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    # Validation is not shuffled: batches keep the same composition every epoch,
    # so per-batch metrics and saved prediction images (named by batch index)
    # stay comparable between epochs.
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
    return train_loader, val_loader

def check_accuracy(loader, model, device=None, threshold=None):
    """Evaluate ``model`` on ``loader`` and print / return segmentation metrics.

    Predictions are ``sigmoid(model(x)) > threshold``. Pixel accuracy,
    precision, recall and specificity are computed from counts accumulated
    over all batches; Dice and IoU are computed per batch and averaged over
    the number of batches.

    Args:
        loader: iterable yielding ``(x, y)`` batches.
        model: network returning logits; switched to eval mode for the
            evaluation and to train mode afterwards.
        device: device the batches are moved to. ``None`` (default) uses the
            module-level ``DEVICE`` as it is at call time.
        threshold: binarisation threshold. ``None`` (default) uses the
            module-level ``THRESHOLD`` as it is at call time.

    Returns:
        ``(accuracy_percent, dice, precision, recall, specificity, iou)`` as
        Python floats. All zeros if ``loader`` yields no batches.
    """
    # Resolve defaults at call time so re-running the config cell takes effect.
    if device is None:
        device = DEVICE
    if threshold is None:
        threshold = THRESHOLD

    num_correct = 0
    num_pixels = 0
    num_batches = 0
    dice_score = 0
    iou_score = 0

    true_positives = 0
    false_positives = 0
    false_negatives = 0
    true_negatives = 0

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            preds = torch.sigmoid(model(x))
            # Author's note: a threshold of 0.7 was tried, but it cut off the sides of the document.
            preds = (preds > threshold).float()

            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            num_batches += 1

            intersection = (preds * y).sum()
            dice_score += (2 * intersection) / ((preds + y).sum() + 1e-8)
            # Author's note: IoU is less sensitive than Dice for small objects; the
            # areas here are large, so IoU reflects performance better.
            union = preds.sum() + y.sum() - intersection
            iou_score += (intersection / (union + 1e-8))

            true_positives += ((preds == 1) & (y == 1)).sum()
            false_positives += ((preds == 1) & (y == 0)).sum()
            false_negatives += ((preds == 0) & (y == 1)).sum()
            true_negatives += ((preds == 0) & (y == 0)).sum()
    model.train()

    if num_batches == 0:
        # Nothing to average; avoid dividing by zero.
        print("No batches to evaluate")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

    # 1e-8 keeps the ratios finite when a denominator count is zero.
    precision = true_positives / (true_positives + false_positives + 1e-8)
    recall = true_positives / (true_positives + false_negatives + 1e-8)
    specificity = true_negatives / (true_negatives + false_positives + 1e-8)
    accuracy = num_correct / num_pixels * 100
    dice = dice_score / num_batches
    iou = iou_score / num_batches

    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}") # Pixel accuracy
    print(f"Dice score: {dice}") # Dice coefficient
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print(f"IoU: {iou:.4f}")
    return accuracy.item(), dice.item(), precision.item(), recall.item(), specificity.item(), iou.item()

def save_predictions_as_imgs(loader, model, pred_folder="../reports/predictions/", true_folder="../reports/true_images/", device=None):
    """Save thresholded predictions and ground-truth masks as PNG grids.

    For batch number ``idx`` of ``loader`` the binarised prediction is written
    to ``pred_folder/pred_{idx}.png`` and the target to
    ``true_folder/true_{idx}.png``. Files from a previous call are overwritten.

    Args:
        loader: iterable yielding ``(x, y)`` batches.
        model: network returning logits; switched to eval mode while saving
            and to train mode afterwards.
        pred_folder: output folder for predictions (created if missing).
        true_folder: output folder for ground-truth masks (created if missing).
        device: device the batches are moved to. ``None`` (default) uses the
            module-level ``DEVICE`` as it is at call time.
    """
    if device is None:
        device = DEVICE

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(pred_folder, exist_ok=True)
    os.makedirs(true_folder, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for idx, (x, y) in enumerate(loader):
            x = x.to(device)
            y = y.to(device)
            preds = torch.sigmoid(model(x))
            preds = (preds > THRESHOLD).float()
            torchvision.utils.save_image(preds, os.path.join(pred_folder, f"pred_{idx}.png"))
            torchvision.utils.save_image(y, os.path.join(true_folder, f"true_{idx}.png"))
    model.train()

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler=None):
    """Run one optimisation pass (one epoch) over ``loader``.

    Args:
        loader: iterable yielding ``(data, targets)`` batches.
        model: network being trained; put into train mode here.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar loss.
        scaler: accepted for call compatibility; not used.

    Returns:
        Mean batch loss of the pass as a float, or ``None`` when ``loader``
        yields no batches.
    """
    # Do not rely on callers having left the model in train mode.
    model.train()

    loss_total = 0.0
    batches_seen = 0
    loop = tqdm(loader)
    for data, targets in loop:
        data = data.to(DEVICE)
        targets = targets.float().to(DEVICE)

        # Forward
        predictions = model(data)
        loss = loss_fn(predictions, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_total += batch_loss
        batches_seen += 1
        loop.set_postfix(loss=batch_loss)

    return loss_total / batches_seen if batches_seen else None


def main():
    """Train the U-Net, track metrics per epoch, export the model and plot the metrics.

    Steps:
      1. Build the transforms, model, loss, optimizer, LR scheduler and loaders.
      2. For ``NUM_EPOCHS`` epochs: train, evaluate on the validation and
         training loaders, step the scheduler, and save prediction images.
      3. Save the weights to ``final_model.pth`` and an ONNX export to
         ``final_model.onnx``.
      4. Write one ``<metric>.png`` plot per metric into ``../reports``.
    """
    train_transform = transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.RandomRotation(35),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomVerticalFlip(0.1),
        transforms.ToTensor(),
        # mean 0 / std 1 leaves the values produced by ToTensor unchanged.
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
    ])

    val_transforms = transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
    ])

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.1)

    train_loader, val_loader = get_loaders(TRAIN_IMG_DIR, TRAIN_MASK_DIR, VAL_IMG_DIR, VAL_MASK_DIR, BATCH_SIZE, train_transform, val_transforms, NUM_WORKERS, PIN_MEMORY)

    # Metric history: validation metrics under their plain name, training
    # metrics under the same name prefixed with "t_". The order matches the
    # tuple returned by check_accuracy.
    metrics = ['accuracy', 'dice', 'precision', 'recall', 'specificity', 'IoU']
    history = {name: [] for name in metrics}
    history.update({f't_{name}': [] for name in metrics})

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn)

        val_scores = check_accuracy(val_loader, model, device=DEVICE)
        for name, value in zip(metrics, val_scores):
            history[name].append(value)

        train_scores = check_accuracy(train_loader, model, device=DEVICE)
        for name, value in zip(metrics, train_scores):
            history[f't_{name}'].append(value)

        scheduler.step()
        # Per-epoch checkpointing through save_checkpoint is currently disabled.
        save_predictions_as_imgs(val_loader, model, pred_folder="../reports/predictions/", true_folder="../reports/true_images/")

    torch.save(model.state_dict(), "final_model.pth")
    print("Model saved as final_model.pth")

    # Export in inference mode so BatchNorm layers use their running statistics.
    model.eval()
    dummy_input = torch.randn(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH).to(DEVICE)
    torch.onnx.export(model, dummy_input, "final_model.onnx",
                      input_names=['input'], output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    print("Model saved as final_model.onnx")

    metrics_df = pd.DataFrame({'epoch': list(range(1, NUM_EPOCHS + 1)), **history})

    def plot_metrics(df, metric, folder):
        """Plot validation vs. training values of ``metric`` and save ``folder/<metric>.png``."""
        plt.figure(figsize=(10, 6))
        plt.plot(df['epoch'], df[metric], marker='o', label=f'Validation {metric}')
        plt.plot(df['epoch'], df[f't_{metric}'], marker='x', label=f'Training {metric}')
        plt.title(f'{metric.capitalize()} over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel(metric.capitalize())
        plt.legend()
        plt.grid(True)
        plt.savefig(f"{folder}/{metric}.png")
        plt.close()

    # Make sure the report folder exists even if no prediction images were written.
    reports_dir = "../reports"
    os.makedirs(reports_dir, exist_ok=True)
    for metric in metrics:
        plot_metrics(metrics_df, metric, reports_dir)
    
# Run the full training pipeline when this cell / file is executed as the main module.
if __name__ == "__main__":
    main()



In [None]:
# Print the layer structure of a UNET built with its default arguments.
print(UNET())

In [None]:
# Threshold selection sweep (this is how the value 0.65 was decided).
# The threshold only changes how sigmoid outputs are binarised at evaluation
# time; it does not enter the training loss. The model is therefore trained
# once and the same trained model is evaluated at every candidate threshold.
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.RandomRotation(35),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
    ])

val_transforms = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
    ])

model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = StepLR(optimizer, step_size=20, gamma=0.1)

train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR, TRAIN_MASK_DIR, VAL_IMG_DIR, VAL_MASK_DIR,
    BATCH_SIZE, train_transform, val_transforms, NUM_WORKERS, PIN_MEMORY
)

# Metric holders: one entry per evaluated threshold, in sweep order.
accuracies = []
dice_scores = []
precisions = []
recalls = []
specificities = []
ious = []

# Candidate thresholds starting at 0.0 in steps of 0.05 (the exact end point of
# np.arange with a float step is subject to floating-point rounding).
thresholds = np.arange(0.0, 1.05, 0.05)
best_threshold = 0.0
best_iou = 0.0

# Train the model once.
for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn)
    scheduler.step()

# check_accuracy reads the module-level THRESHOLD, so it is set per candidate
# and restored afterwards to keep the configured value for later cells.
_configured_threshold = THRESHOLD
try:
    for threshold in thresholds:
        THRESHOLD = threshold
        print(f"Evaluating with threshold: {threshold}")

        accuracy, dice, precision, recall, specificity, iou = check_accuracy(val_loader, model, device=DEVICE)
        print(f"Threshold: {threshold}, IoU: {iou}")

        accuracies.append(accuracy)
        dice_scores.append(dice)
        precisions.append(precision)
        recalls.append(recall)
        specificities.append(specificity)
        ious.append(iou)

        # Keep the threshold with the highest validation IoU seen so far.
        if iou > best_iou:
            best_iou = iou
            best_threshold = threshold
finally:
    THRESHOLD = _configured_threshold

print(f"Best threshold: {best_threshold}, Best IoU: {best_iou}")
