In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image
import os
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import torch.nn.functional as F
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import cv2
import segmentation_models_pytorch as smp

folder_path = "20230530_segm_black_mouse_mnSLA_red_and_black_back"

seed_value = 52
torch.manual_seed(seed_value)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [2]:
def augmentation():
    count_image = 350
    path_images = "20230530_segm_black_mouse_mnSLA_red_and_black_back/images"
    path_masks = "20230530_segm_black_mouse_mnSLA_red_and_black_back/masks"

    image_files = os.listdir(path_images)
    selected_files_for_vertical = random.sample(image_files, count_image)

    for filename in selected_files_for_vertical:
        image_path = os.path.join(path_images, filename)
        image = Image.open(image_path)
        rotated_image = image.transpose(Image.FLIP_LEFT_RIGHT)
        output_path = os.path.join(path_images, f'revert_vertical_{filename}')
        rotated_image.save(output_path)
        image.close()

        mask_path = os.path.join(path_masks, filename)
        mask = Image.open(mask_path)
        rotated_mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        output_path = os.path.join(path_masks, f'revert_vertical_{filename}')
        rotated_mask.save(output_path)
        mask.close()

    # selected_files_for_horizontal = []
    count_image = 1
    selected_files_for_horizontal = random.sample(image_files, count_image)

    for filename in selected_files_for_horizontal:
        image_path = os.path.join(path_images, filename)
        image = Image.open(image_path)
        rotated_image = image.transpose(Image.FLIP_TOP_BOTTOM)
        output_path = os.path.join(path_images, f'revert_horizontal_{filename}')
        rotated_image.save(output_path)
        image.close()

        mask_path = os.path.join(path_masks, filename)
        mask = Image.open(mask_path)
        rotated_mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
        output_path = os.path.join(path_masks, f'revert_horizontal_{filename}')
        rotated_mask.save(output_path)
        mask.close()

    return selected_files_for_vertical, selected_files_for_horizontal


def delete_generated_images(data_vert, data_hor):
    path_images = "20230530_segm_black_mouse_mnSLA_red_and_black_back/images"
    path_masks = "20230530_segm_black_mouse_mnSLA_red_and_black_back/masks"

    for filename in data_vert:
        output_image_path = os.path.join(path_images, f'revert_vertical_{filename}')
        os.remove(output_image_path)

        output_mask_path = os.path.join(path_masks, f'revert_vertical_{filename}')
        os.remove(output_mask_path)

    for filename in data_hor:
        output_image_path = os.path.join(path_images, f'revert_horizontal_{filename}')
        os.remove(output_image_path)

        output_mask_path = os.path.join(path_masks, f'revert_horizontal_{filename}')
        os.remove(output_mask_path)

In [3]:
def make_csv_files(folder_path, folder):
    images_folder = folder_path + "/" + folder + "images"
    masks_folder = folder_path + "/" + folder + "masks"

    images_files = os.listdir(images_folder)
    masks_files = os.listdir(masks_folder)

    image_paths = [os.path.join(folder + "images", file) for file in images_files]
    mask_paths = [os.path.join(folder + "masks", file) for file in masks_files]

    data = {'orig_image': image_paths, 'mask_image': mask_paths}
    df = pd.DataFrame(data)

    csv_file_path = "train_data.csv" if folder == "" else "test_data.csv"

    df.to_csv(csv_file_path, index=False)

In [4]:
data_vertical, data_horizontal = augmentation()
make_csv_files(folder_path, "")
make_csv_files(folder_path, "test_")

In [9]:
train_df = pd.read_csv("train_data.csv")
test = pd.read_csv("test_data.csv")

In [10]:
def draw(orig_image, orig_masks, mask_image, intersec_mask):
    fig, axes = plt.subplots(1, 4)

    orig_image = orig_image.transpose(1, 2, 0)
    orig_image = (np.array(orig_image) - np.min(orig_image)) / (np.max(orig_image) - np.min(orig_image))
    axes[0].imshow(orig_image)
    axes[0].set_title('Original Image')

    axes[1].imshow(orig_masks)
    axes[1].set_title('Original Mask')
    
    axes[2].imshow(mask_image)
    axes[2].set_title('Predicted Mask')

    axes[3].imshow(intersec_mask)
    axes[3].set_title('Difference Mask')

    plt.tight_layout()
    plt.show()

In [11]:
def calculate_iou(predictions, targets):
    total_sum = 0.0
    for prediction, target in zip(predictions, targets):
        intersection = np.logical_and(prediction, target).sum().item()
        union = np.logical_or(prediction, target).sum().item()
        
        total_sum += intersection / union if union > 0 else 0.0

    return total_sum / len(predictions)

In [12]:
def ap_k(predictions, targets, k, klass):
    precision = 0
    tp = 0
    all_det = 0
    for i in range(k):
        predict = predictions[i]
        target = targets[i]

        if klass:
            tp += np.sum(predict[predict == target])
            all_det += np.sum(predict)
            precision += tp / all_det
        else:
            tp += len(predict[predict == target]) - np.count_nonzero(predict[predict == target])
            all_det += predict.size - np.count_nonzero(predict)
            precision += tp / all_det
    return precision / k 

def compute_ap(predictions, targets, k):
    return (ap_k(predictions, targets, k, True) + ap_k(predictions, targets, k, False)) / 2, ap_k(predictions, targets, k, True)

In [13]:
class ImagesDataset(Dataset):
    def __init__(self, folder, data, transform_image, transform_mask, size):
      self.folder = folder
      self.data = data.copy()
      self.orig_image_paths = [os.path.join(folder, filename) for filename in data['orig_image'].copy()]
      self.mask_image_paths = [os.path.join(folder, filename) for filename in data['mask_image'].copy()]
      self.transform_image = transform_image
      self.transform_mask = transform_mask
      self.size = size

    def __len__(self):
        return len(self.orig_image_paths)

    def __getitem__(self, idx):
        orig_image_path = self.orig_image_paths[idx]
        mask_image_path = self.mask_image_paths[idx]
        orig_image = Image.open(orig_image_path).convert('RGB')
        mask_image = Image.open(mask_image_path).convert('L')
        
        orig_image = self.transform_image(orig_image)
        orig_image = orig_image.to(orig_image)
        
        mask_image = self.transform_mask(mask_image)
        mask_image = mask_image.to(mask_image)
        
        return {'image': orig_image.float().to(device), 'mask_input': mask_image.float().to(device), 'original_size': self.size}

In [14]:
def collate_fn(batch):
    return batch

In [28]:
from sklearn.model_selection import train_test_split

train, val = train_test_split(train_df, test_size=0.1 , random_state=42)

size = (32, 54)
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
batch_size = 8

transform_image = transforms.Compose([transforms.Resize(size), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

transform_mask = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])

train_dataset = ImagesDataset(folder_path, train, transform_image, transform_mask, size)
val_dataset = ImagesDataset(folder_path, val, transform_image, transform_mask, size)
test_dataset = ImagesDataset(folder_path, test, transform_image, transform_mask, size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [30]:
def learning(num_epochs, train_load, val_load, model, optimizer, criterion, model_name, scheduler=None):
  train_losses = []
  val_losses = []
  iou_test = []
  max_iou = 0.0

  for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_load):
        optimizer.zero_grad()
        # outputs = model(batch, multimask_output=False)
        outputs = model(pixel_values=batch["image"],
                        input_boxes=batch["mask_input"],
                        multimask_output=False)
        # prediction_masks = torch.stack([x["masks"] for x in outputs], dim=0)
        orig_masks = torch.stack([x["mask_input"] for x in batch], dim=0)
        prediction_masks = outputs.pred_masks.squeeze(1)
        loss = criterion(orig_masks, prediction_masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_losses.append(train_loss/len(train_load))
    
    if scheduler is not None:
      scheduler.step(loss)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_load):
          outputs = model(batch, multimask_output=False)
          prediction_masks = torch.stack([x["masks"] for x in outputs], dim=0)
          orig_masks = torch.stack([x["mask_input"] for x in batch], dim=0)
          loss = criterion(orig_masks, prediction_masks)
          val_loss += (loss.item())
    val_losses.append(val_loss/len(val_load))

    predictions, _, orig_masks, _ = prediction(model, test_loader)
    iou = calculate_iou(predictions, orig_masks)
    iou_test.append(iou)

    if iou > max_iou:
       max_iou = iou
       torch.save(model.state_dict(), f"models/{model_name}_IOU-{iou}.pth")

    print(f"Epoch [{epoch+1}/{num_epochs}], Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}, IOU: {iou:.4f}")
  return model, train_losses, val_losses, iou_test

In [11]:
def prediction(model, loader):
    model.eval()
    predictions = []
    orig_images = []
    orig_masks = []
    intersection_masks = []
    size = (544, 928)
    transform = transforms.Compose([transforms.Resize(size)])
    with torch.no_grad():
        for batch in tqdm(loader):
            outputs = model(batch, multimask_output=False)
            prediction_masks = torch.stack([x["masks"] for x in outputs], dim=0)
            orig_images = torch.stack([x["images"] for x in batch], dim=0)
            orig_masks = torch.stack([x["mask_input"] for x in batch], dim=0)
            #new_outputs = torch.zeros(prediction_masks.shape[0], prediction_masks.shape[1], prediction_masks.shape[2], prediction_masks.shape[3])
            prediction_masks[prediction_masks < 0] = 0
            prediction_masks[prediction_masks > 0] = 1

            # for i in range(outputs.shape[0]):
            #     for j in range(outputs.shape[2]):
            #         for k in range(outputs.shape[3]):
            #             if outputs[i, 0, j, k] == 1:
            #                 for m in range(-1, 2):
            #                     for n in range(-1, 2):
            #                         if 0 <= j + m < outputs.shape[2] - 4 and 0 <= k + n < outputs.shape[3] - 4:
            #                             new_outputs[i, 0, j + 1 + m, k + 1 + n] = 1

            # outputs = new_outputs

            prediction_masks = transform(prediction_masks)
            orig_images = transform(orig_images)
            orig_masks = transform(orig_masks)

            predictions.append(prediction_masks.cpu().numpy())
            orig_images.append(orig_images.cpu().numpy())
            orig_masks.append(orig_masks.cpu().numpy())

            intersection = np.abs(prediction_masks.cpu().numpy() - orig_masks.cpu().numpy())
            intersection_masks.append(intersection)
    predictions = np.concatenate(predictions, axis=0).squeeze()
    orig_images = np.concatenate(orig_images, axis=0).squeeze()
    orig_masks = np.concatenate(orig_masks, axis=0).squeeze()
    intersection_masks = np.concatenate(intersection_masks, axis=0).squeeze()
    return predictions, orig_images, orig_masks, intersection_masks

In [17]:
def validation(model, loader, images_to_draw):
    predictions, orig_images, orig_masks, intersection_masks = prediction(model, loader)

    iou = calculate_iou(predictions, orig_masks)
    apk, ap = compute_ap(predictions, orig_masks, len(predictions))

    print(f"IOU: {iou}")
    print(f"AP for Two Classes: {apk}")
    print(f"AP for Mouse Class: {ap}")

    for i in range(images_to_draw):
        draw(orig_images[i], orig_masks[i], predictions[i], intersection_masks[i])

In [13]:
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import cv2
import segmentation_models_pytorch as smp

sam = sam_model_registry["vit_h"](checkpoint='models/sam_vit_h_4b8939.pth')
sam.cuda()
model_name = 'sam_vit_h_4b8939'
for param in sam.parameters():
    param.requires_grad=False


In [22]:
from transformers import SamModel

sam = SamModel.from_pretrained("facebook/sam-vit-base")
model_name = "sam-vit-base"

for name, param in sam.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)

In [20]:
from torch.optim import Adam
import monai 

optimizer = Adam(sam.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
criterion = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
from statistics import mean
from torch.nn.functional import threshold, normalize

In [14]:
# import segmentation_models_pytorch as smp

# learning_rate = 0.001
# optimizer = optim.Adamax(sam.parameters(), lr=learning_rate)
# criterion = smp.losses.DiceLoss(mode='binary')
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

In [31]:
sam.to(device)
num_epochs = 5
sam.train()
sam, train_losses, val_losses, iou_test = learning(num_epochs, train_loader, val_loader, sam, optimizer, criterion, model_name)

  0%|          | 0/139 [00:00<?, ?it/s]


ValueError: ('The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.', ' got torch.Size([8, 1, 32, 54]).')

In [None]:
plt.subplot(3, 1, 1)
plt.plot(train_losses, label='Training Losses', color='blue')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Losses')
plt.legend()

plt.subplot(3, 1, 2)
plt.plot(val_losses, label='Validation Losses', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Validation Losses')
plt.legend()

plt.subplot(3, 1, 3)
plt.plot(iou_test, label='IOU test', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('IOU Losses')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
validation(model, train_loader, 1)

In [None]:
validation(model, val_loader, 3)

In [None]:
validation(model, test_loader, 10)