In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image
import os
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import torch.nn.functional as F
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import cv2
import segmentation_models_pytorch as smp

folder_path = "20230530_segm_black_mouse_mnSLA_red_and_black_back"

seed_value = 52
torch.manual_seed(seed_value)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [12]:
def augmentation():
    count_image = 350
    path_images = "20230530_segm_black_mouse_mnSLA_red_and_black_back/images"
    path_masks = "20230530_segm_black_mouse_mnSLA_red_and_black_back/masks"

    image_files = os.listdir(path_images)
    selected_files_for_vertical = random.sample(image_files, count_image)

    for filename in selected_files_for_vertical:
        image_path = os.path.join(path_images, filename)
        image = Image.open(image_path)
        rotated_image = image.transpose(Image.FLIP_LEFT_RIGHT)
        output_path = os.path.join(path_images, f'revert_vertical_{filename}')
        rotated_image.save(output_path)
        image.close()

        mask_path = os.path.join(path_masks, filename)
        mask = Image.open(mask_path)
        rotated_mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        output_path = os.path.join(path_masks, f'revert_vertical_{filename}')
        rotated_mask.save(output_path)
        mask.close()

    # selected_files_for_horizontal = []
    count_image = 1
    selected_files_for_horizontal = random.sample(image_files, count_image)

    for filename in selected_files_for_horizontal:
        image_path = os.path.join(path_images, filename)
        image = Image.open(image_path)
        rotated_image = image.transpose(Image.FLIP_TOP_BOTTOM)
        output_path = os.path.join(path_images, f'revert_horizontal_{filename}')
        rotated_image.save(output_path)
        image.close()

        mask_path = os.path.join(path_masks, filename)
        mask = Image.open(mask_path)
        rotated_mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
        output_path = os.path.join(path_masks, f'revert_horizontal_{filename}')
        rotated_mask.save(output_path)
        mask.close()

    return selected_files_for_vertical, selected_files_for_horizontal


def delete_generated_images(data_vert, data_hor):
    path_images = "20230530_segm_black_mouse_mnSLA_red_and_black_back/images"
    path_masks = "20230530_segm_black_mouse_mnSLA_red_and_black_back/masks"

    for filename in data_vert:
        output_image_path = os.path.join(path_images, f'revert_vertical_{filename}')
        os.remove(output_image_path)

        output_mask_path = os.path.join(path_masks, f'revert_vertical_{filename}')
        os.remove(output_mask_path)

    for filename in data_hor:
        output_image_path = os.path.join(path_images, f'revert_horizontal_{filename}')
        os.remove(output_image_path)

        output_mask_path = os.path.join(path_masks, f'revert_horizontal_{filename}')
        os.remove(output_mask_path)

In [13]:
def make_csv_files(folder_path, folder):
    images_folder = folder_path + "/" + folder + "images"
    masks_folder = folder_path + "/" + folder + "masks"

    images_files = os.listdir(images_folder)
    masks_files = os.listdir(masks_folder)

    image_paths = [os.path.join(folder + "images", file) for file in images_files]
    mask_paths = [os.path.join(folder + "masks", file) for file in masks_files]

    data = {'orig_image': image_paths, 'mask_image': mask_paths}
    df = pd.DataFrame(data)

    csv_file_path = "train_data.csv" if folder == "" else "test_data.csv"

    df.to_csv(csv_file_path, index=False)

In [14]:
def draw(orig_image, orig_masks, mask_image, intersec_mask):
    fig, axes = plt.subplots(1, 4)

    orig_image = orig_image.transpose(1, 2, 0)
    orig_image = (np.array(orig_image) - np.min(orig_image)) / (np.max(orig_image) - np.min(orig_image))
    axes[0].imshow(orig_image)
    axes[0].set_title('Original Image')

    axes[1].imshow(orig_masks)
    axes[1].set_title('Original Mask')
    
    axes[2].imshow(mask_image)
    axes[2].set_title('Predicted Mask')

    axes[3].imshow(intersec_mask)
    axes[3].set_title('Difference Mask')

    plt.tight_layout()
    plt.show()

In [15]:
def calculate_iou(predictions, targets):
    total_sum = 0.0
    for prediction, target in zip(predictions, targets):
        intersection = np.logical_and(prediction, target).sum().item()
        union = np.logical_or(prediction, target).sum().item()
        
        total_sum += intersection / union if union > 0 else 0.0

    return total_sum / len(predictions)

In [16]:
def ap_k(predictions, targets, k, klass):
    precision = 0
    tp = 0
    all_det = 0
    for i in range(k):
        predict = predictions[i]
        target = targets[i]

        if klass:
            tp += np.sum(predict[predict == target])
            all_det += np.sum(predict)
            precision += tp / all_det
        else:
            tp += len(predict[predict == target]) - np.count_nonzero(predict[predict == target])
            all_det += predict.size - np.count_nonzero(predict)
            precision += tp / all_det
    return precision / k 

def compute_ap(predictions, targets, k):
    return (ap_k(predictions, targets, k, True) + ap_k(predictions, targets, k, False)) / 2, ap_k(predictions, targets, k, True)

In [17]:
def get_bounding_box(ground_truth_map):
  bbox_array = []
  for ground_truth_mask in ground_truth_map:
    y_indices, x_indices = np.where(ground_truth_mask > 0)
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)
    H, W = ground_truth_mask.shape
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))
    bbox = [x_min, y_min, x_max, y_max]
    bbox_array.append(bbox)
  return bbox_array

In [18]:
# data_vertical, data_horizontal = augmentation()
#make_csv_files(folder_path, "")
#make_csv_files(folder_path, "test_")

In [19]:
train_images = os.listdir('20230530_segm_black_mouse_mnSLA_red_and_black_back/images')
train_masks = os.listdir('20230530_segm_black_mouse_mnSLA_red_and_black_back/masks')
test_images = os.listdir('20230530_segm_black_mouse_mnSLA_red_and_black_back/test_images')
test_masks = os.listdir('20230530_segm_black_mouse_mnSLA_red_and_black_back/test_masks')

In [2]:
from PIL import Image

def merge_tiff(input_files, output_file):
    images = [Image.open(file) for file in input_files]
    
    # Получаем размеры и режим первого изображения
    width, height = images[0].size
    mode = images[0].mode
    
    # Проверяем, что все изображения имеют одинаковые размеры и режим
    for img in images[1:]:
        if img.size != (width, height):
            img = img.resize((width, height), Image.ANTIALIAS)
            print(width, height)
    
    # Создаем новое изображение, к которому будем добавлять кадры
    result_image = Image.new(mode, (width, height * len(images)))

    # Сливаем каждое изображение в новое изображение
    for i, img in enumerate(images):
        result_image.paste(img, (0, i * height))

    # Сохраняем объединенное изображение
    result_image.save(output_file, format='TIFF', compression='tiff_lzw')
    # result_image.save(output_file)

    # Закрываем все изображения
    for img in images:
        img.close()

def get_tiff_files(folder):
    tiff_files = []
    for file in os.listdir(folder):
        if file.endswith(".tif") or file.endswith(".tiff"):
            tiff_files.append(os.path.join(folder, file))
    return tiff_files

folder_path = "tiff_data/images"

tiff_files = get_tiff_files(folder_path)

output_file = "images.tif"

merge_tiff(tiff_files, output_file)

OSError: encoder error -2 when writing image file

In [22]:
from datasets import Dataset
from PIL import Image
from sklearn.model_selection import train_test_split

train_images, val_images, train_labels, val_labels = train_test_split(train_images, train_masks, test_size=0.1, random_state=42)
val_images = val_images[:-1]
val_labels = val_labels[:-1]
train_images = train_images[:-1]
train_labels = train_labels[:-1]

train_dataset_dict = {
    "image": [Image.open(os.path.join('20230530_segm_black_mouse_mnSLA_red_and_black_back/images', img)).resize((256, 256)) for img in train_images],
    "label": [Image.open(os.path.join('20230530_segm_black_mouse_mnSLA_red_and_black_back/masks', mask)).resize((256, 256)) for mask in train_labels],
}

val_dataset_dict = {
    "image": [Image.open(os.path.join('20230530_segm_black_mouse_mnSLA_red_and_black_back/images', img)).resize((256, 256)) for img in val_images],
    "label": [Image.open(os.path.join('20230530_segm_black_mouse_mnSLA_red_and_black_back/masks', mask)).resize((256, 256)) for mask in val_labels],
}

test_dataset_dict = {
    "image": [Image.open(os.path.join('20230530_segm_black_mouse_mnSLA_red_and_black_back/test_images', img)).resize((256, 256)) for img in test_images],
    "label": [Image.open(os.path.join('20230530_segm_black_mouse_mnSLA_red_and_black_back/test_masks', mask)).resize((256, 256)) for mask in test_masks],
}

train_set = Dataset.from_dict(train_dataset_dict)
val_set = Dataset.from_dict(val_dataset_dict)
test_set = Dataset.from_dict(test_dataset_dict)

KeyboardInterrupt: 

In [None]:
train_set, val_set, test_set

(Dataset({
     features: ['image', 'label'],
     num_rows: 1104
 }),
 Dataset({
     features: ['image', 'label'],
     num_rows: 122
 }),
 Dataset({
     features: ['image', 'label'],
     num_rows: 49
 }))

In [None]:
class SAMDataset(Dataset):
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    image = item["image"]
    ground_truth_mask = np.array(item["label"])

    ground_truth_mask[ground_truth_mask > 0] = 1
    
    prompt = get_bounding_box(ground_truth_mask)

    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    inputs = {k:v.squeeze(0) for k,v in inputs.items()}

    inputs["ground_truth_mask"] = ground_truth_mask

    return inputs

In [None]:
def learning(num_epochs, train_load, val_load, model, optimizer, criterion, model_name, scheduler=None):
  train_losses = []
  val_losses = []
  iou_test = []
  max_iou = 0.0

  for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_load):
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].unsqueeze(1).to(device),
                      multimask_output=False)
        predicted_masks = outputs.pred_masks.squeeze(1)
        predicted_masks = torch.sigmoid(predicted_masks)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        optimizer.zero_grad()

        loss = criterion(ground_truth_masks, predicted_masks.squeeze(1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_losses.append(train_loss/len(train_load))
    
    if scheduler is not None:
      scheduler.step(loss)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_load):
          outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].unsqueeze(1).to(device),
                      multimask_output=False)
          predicted_masks = outputs.pred_masks.squeeze(1)
          predicted_masks = torch.sigmoid(predicted_masks)
          ground_truth_masks = batch["ground_truth_mask"].float().to(device)
          loss = criterion(predicted_masks, ground_truth_masks.unsqueeze(1))
          val_loss += (loss.item())
    val_losses.append(val_loss/len(val_load))

    predictions, _, orig_masks, _ = prediction(model, test_loader)
    iou = calculate_iou(predictions, orig_masks)
    iou_test.append(iou)

    if iou > max_iou:
       max_iou = iou
       torch.save(model.state_dict(), f"models/{model_name}_IOU-{iou}.pth")

    print(f"Epoch [{epoch+1}/{num_epochs}], Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}, IOU: {iou:.4f}")
  return model, train_losses, val_losses, iou_test

In [None]:
def prediction(model, loader):
    model.eval()
    predictions = []
    orig_images = []
    orig_masks = []
    intersection_masks = []
    size = (544, 928)
    threshold = 0
    transform = transforms.Compose([transforms.Resize(size)])
    with torch.no_grad():
        for batch in tqdm(loader):
            outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].unsqueeze(1).to(device),
                      multimask_output=False)
            predicted_masks = outputs.pred_masks.squeeze(1)
            ground_truth_masks = batch["ground_truth_mask"].float().to(device).unsqueeze(1)
            #predicted_masks = torch.sigmoid(predicted_masks)
            #new_outputs = torch.zeros(prediction_masks.shape[0], prediction_masks.shape[1], prediction_masks.shape[2], prediction_masks.shape[3])
            predicted_masks[predicted_masks < threshold] = 0
            predicted_masks[predicted_masks > threshold] = 1

            # for i in range(outputs.shape[0]):
            #     for j in range(outputs.shape[2]):
            #         for k in range(outputs.shape[3]):
            #             if outputs[i, 0, j, k] == 1:
            #                 for m in range(-1, 2):
            #                     for n in range(-1, 2):
            #                         if 0 <= j + m < outputs.shape[2] - 4 and 0 <= k + n < outputs.shape[3] - 4:
            #                             new_outputs[i, 0, j + 1 + m, k + 1 + n] = 1

            # outputs = new_outputs

            predicted_masks = transform(predicted_masks)
            orig_image = transform(batch["pixel_values"].float())
            orig_mask = transform(ground_truth_masks)

            predictions.append(predicted_masks.cpu().numpy())
            orig_images.append(orig_image.cpu().numpy())
            orig_masks.append(orig_mask.cpu().numpy())

            intersection = np.abs(predicted_masks.cpu().numpy() - orig_mask.cpu().numpy())
            intersection_masks.append(intersection)
    predictions = np.concatenate(predictions, axis=0).squeeze()
    orig_images = np.concatenate(orig_images, axis=0).squeeze()
    orig_masks = np.concatenate(orig_masks, axis=0).squeeze()
    intersection_masks = np.concatenate(intersection_masks, axis=0).squeeze()
    return predictions, orig_images, orig_masks, intersection_masks

In [None]:
def validation(model, loader, images_to_draw):
    predictions, orig_images, orig_masks, intersection_masks = prediction(model, loader)

    iou = calculate_iou(predictions, orig_masks)
    apk, ap = compute_ap(predictions, orig_masks, len(predictions))

    print(f"IOU: {iou}")
    print(f"AP for Two Classes: {apk}")
    print(f"AP for Mouse Class: {ap}")

    for i in range(images_to_draw):
        draw(orig_images[i], orig_masks[i], predictions[i], intersection_masks[i])

In [None]:
from transformers import SamProcessor
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [None]:
batch_size = 2

train_dataset = SAMDataset(dataset=train_set, processor=processor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_dataset = SAMDataset(dataset=val_set, processor=processor)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
test_dataset = SAMDataset(dataset=test_set, processor=processor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

In [None]:
from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
model_name = "sam-vit-base"

for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)

In [None]:
from torch.optim import Adam
import monai 

lr = 1e-5

optimizer = Adam(model.mask_decoder.parameters(), lr=lr, weight_decay=0)
criterion = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

num_epochs = 10

In [None]:
model, train_losses, val_losses, iou_test = learning(num_epochs, train_loader, val_loader, model, optimizer, criterion, model_name)

100%|██████████| 552/552 [03:46<00:00,  2.44it/s]
100%|██████████| 61/61 [00:24<00:00,  2.50it/s]
100%|██████████| 24/24 [00:09<00:00,  2.42it/s]


Epoch [1/10], Train loss: 2300.9990, Val loss: 97.8408, IOU: 0.0000


 95%|█████████▍| 523/552 [03:35<00:11,  2.43it/s]


KeyboardInterrupt: 

In [None]:
plt.subplot(3, 1, 1)
plt.plot(train_losses, label='Training Losses', color='blue')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Losses')
plt.legend()

plt.subplot(3, 1, 2)
plt.plot(val_losses, label='Validation Losses', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Validation Losses')
plt.legend()

plt.subplot(3, 1, 3)
plt.plot(iou_test, label='IOU test', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('IOU Losses')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
validation(model, train_loader, 1)

In [None]:
validation(model, val_loader, 3)

In [None]:
validation(model, test_loader, 10)