# НУЖНО ВЫБРАТЬ МЕТРИКУ И ОБУЧИТЬ МОДЕЛЬ

In [63]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image
import os
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import sklearn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss

# Root directory that the relative paths in the CSV tables are joined onto.
folder_path = "data/20230904_segm_rat_OFT_gray_back/20230904_segm_rat_OFT_gray_back"

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Tables listing the image/mask file pairs for training and testing.
train_df, test = (pd.read_csv(csv_name) for csv_name in ("train_data.csv", "test_data.csv"))

cpu


In [64]:
# !pip install segmentation-models-pytorch

In [65]:
# import os
# import pandas as pd

# train_images_folder = folder_path + "/test_images"
# train_masks_folder = folder_path + "/test_masks"

# train_images_files = os.listdir(train_images_folder)
# train_masks_files = os.listdir(train_masks_folder)

# image_paths = [os.path.join("test_images", file) for file in train_images_files]
# mask_paths = [os.path.join("test_masks", file) for file in train_masks_files]

# data = {'orig_image': image_paths, 'mask_image': mask_paths}
# df = pd.DataFrame(data)

# csv_file_path = "test_data.csv"

# df.to_csv(csv_file_path, index=False)

# print("CSV file created successfully!")

In [66]:
def draw(orig_image, mask_image):
    """Show two images side by side in one figure.

    Args:
        orig_image: Data passed to ``imshow`` for the left panel,
            titled 'Original Image'.
        mask_image: Data passed to ``imshow`` for the right panel,
            titled 'Mask Image'.
    """
    _, (image_ax, mask_ax) = plt.subplots(1, 2)

    panels = (
        (image_ax, orig_image, 'Original Image'),
        (mask_ax, mask_image, 'Mask Image'),
    )
    for panel_ax, picture, caption in panels:
        panel_ax.imshow(picture)
        panel_ax.set_title(caption)

    plt.tight_layout()
    plt.show()

### Preprocessing (подготовка данных)

In [67]:
class ImagesDataset(Dataset):
    """Dataset of (image, mask) pairs listed in a table of relative file paths.

    Args:
        folder: Directory that every path from ``data`` is joined onto.
        data: Table with ``'orig_image'`` and ``'mask_image'`` columns holding
            file paths relative to ``folder``.
        transform_image: Callable applied to the RGB image. Its result must
            support ``.float()``.
        transform_mask: Callable applied to the single-channel ('L') mask.
            Its result must support ``.float()``.
    """

    def __init__(self, folder, data, transform_image, transform_mask):
        self.folder = folder
        self.data = data.copy()
        self.orig_image_paths = [os.path.join(folder, filename) for filename in data['orig_image']]
        self.mask_image_paths = [os.path.join(folder, filename) for filename in data['mask_image']]
        self.transform_image = transform_image
        self.transform_mask = transform_mask

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.orig_image_paths)

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th pair as float tensors.

        Returns:
            Tuple ``(image, mask)``: the transformed image and the transformed
            mask, each converted with ``.float()``.
        """
        # ``convert`` returns a fully loaded copy, so the underlying files can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(self.orig_image_paths[idx]) as image_file:
            orig_image = image_file.convert('RGB')
        with Image.open(self.mask_image_paths[idx]) as mask_file:
            mask_image = mask_file.convert('L')

        orig_image = self.transform_image(orig_image)
        mask_image = self.transform_mask(mask_image)

        return orig_image.float(), mask_image.float()

In [68]:
class ReplicateChannel:
    """Callable transform that repeats its input three times along the first of three axes."""

    def __call__(self, img):
        """Return ``img.repeat(3, 1, 1)``.

        NOTE(review): presumably ``img`` is a single-channel (1, H, W) tensor
        being expanded to three channels - confirm against the caller.
        """
        repeats = (3, 1, 1)
        return img.repeat(*repeats)

In [69]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the training table for validation; the fixed seed keeps the
# split reproducible between runs.
train, val = train_test_split(train_df, test_size=0.1, random_state=42)

size = (320, 320)
# Per-channel statistics used to normalise the input images.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
batch_size = 8

ENCODER = 'timm-efficientnet-b0'
WEIGHTS = 'imagenet'

transform_image = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Masks are targets, not inputs, so they must NOT be normalised with image
# statistics: that would shift the values away from [0, 1], which the binary
# Dice loss expects. Nearest-neighbour resizing avoids creating intermediate
# grey levels along the mask boundary.
# NOTE(review): assumes the mask files are binary (0 / 255) images - confirm.
transform_mask = transforms.Compose([
    transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

train_dataset = ImagesDataset(folder_path, train, transform_image, transform_mask)
val_dataset = ImagesDataset(folder_path, val, transform_image, transform_mask)
test_dataset = ImagesDataset(folder_path, test, transform_image, transform_mask)

# Only the training loader shuffles; validation and test keep the table order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Training/evaluation loop

In [70]:
def learning(num_epochs, train_load, val_load, model, optimizer, criterion):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        num_epochs: Number of passes over ``train_load``.
        train_load: Iterable of ``(x_batch, y_batch)`` training batches;
            must support ``len``.
        val_load: Iterable of ``(x_batch, y_batch)`` validation batches;
            must support ``len``.
        model: Module called as ``model(x_batch)``.
        optimizer: Optimizer stepped once per training batch.
        criterion: Callable ``criterion(outputs, y_batch)`` returning a
            scalar loss tensor.

    Returns:
        Tuple ``(model, train_losses, val_losses)`` where the two lists hold
        the mean per-batch loss of every epoch.
    """
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for x_batch, y_batch in tqdm(train_load):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(x_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Average over batches so the logged value matches the recorded one.
        train_loss /= len(train_load)
        train_losses.append(train_loss)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x_batch, y_batch in tqdm(val_load):
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                outputs = model(x_batch)
                loss = criterion(outputs, y_batch)
                val_loss += loss.item()
        val_loss /= len(val_load)
        val_losses.append(val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    return model, train_losses, val_losses

### Prediction function

In [71]:
def prediction(model, loader):
    """Run ``model`` over ``loader`` without gradients and collect the results.

    Inputs are moved to the module-level ``device`` before the forward pass.

    Args:
        model: Module called as ``model(x_batch)``.
        loader: Iterable of ``(x_batch, y_batch)`` batches.

    Returns:
        Tuple ``(predictions, orig_images)`` of flattened 1-D NumPy arrays:
        the raw model outputs and the second element (``y_batch``) of every
        batch, both concatenated in loader order.
    """
    model.eval()
    predictions = []
    orig_images = []
    with torch.no_grad():
        for x_batch, y_batch in tqdm(loader):
            outputs = model(x_batch.to(device))
            predictions.append(outputs.cpu().numpy())
            # Previously ``outputs`` was appended here a second time, so both
            # returned arrays were identical and ``y_batch`` was never used.
            # NOTE(review): the targets are assumed to be what the caller
            # compares the predictions against - confirm.
            orig_images.append(y_batch.cpu().numpy())
    predictions = np.concatenate(predictions, axis=0).flatten()
    orig_images = np.concatenate(orig_images, axis=0).flatten()
    return predictions, orig_images

In [72]:
def validation(model, loader, images_to_draw):
    """Print the Dice loss of ``model`` on ``loader`` and plot a few samples.

    Args:
        model: Model forwarded to ``prediction``.
        loader: DataLoader forwarded to ``prediction``; its ``dataset`` is
            also read once to recover the per-sample height and width.
        images_to_draw: Maximum number of samples passed to ``draw``.
    """
    predictions, orig_images = prediction(model, loader)

    # ``DiceLoss`` is a module: it is configured first and then called on
    # ``(y_pred, y_true)`` tensors. The flattened arrays are treated as one
    # sample with one channel, giving a single score over every pixel.
    criterion = DiceLoss(mode='binary')
    pred_tensor = torch.as_tensor(predictions).reshape(1, 1, -1)
    target_tensor = torch.as_tensor(orig_images).reshape(1, 1, -1)
    loss = criterion(pred_tensor, target_tensor)
    print(loss.item())

    count = max(images_to_draw, 0)
    if count == 0:
        return

    # ``prediction`` returns flattened arrays, so indexing them yields
    # scalars. Restore one 2-D map per sample before plotting.
    # NOTE(review): assumes single-channel outputs whose spatial size equals
    # that of the dataset masks - confirm.
    _, sample_mask = loader.dataset[0]
    height, width = sample_mask.shape[-2:]
    pred_maps = predictions.reshape(-1, height, width)
    target_maps = orig_images.reshape(-1, height, width)

    for target_map, pred_map in zip(target_maps[:count], pred_maps[:count]):
        draw(target_map, pred_map)


### Experiments

In [73]:
class SegmentationModel(nn.Module):
    """U-Net++ wrapper that returns raw logits and, optionally, a training loss.

    The network is built from the module-level ``ENCODER`` and ``WEIGHTS``
    settings, with three input channels, a single output class and no final
    activation.
    """

    def __init__(self):
        super().__init__()

        self.model = smp.UnetPlusPlus(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None)

    def forward(self, images, masks=None):
        """Compute logits for ``images`` and, when ``masks`` is given, the loss.

        Args:
            images: Input batch passed to the wrapped network.
            masks: Optional target batch. When provided, the sum of the binary
                Dice loss and BCE-with-logits loss against it is returned too.

        Returns:
            ``logits`` when ``masks`` is None, otherwise ``(logits, loss)``.
        """
        logits = self.model(images)

        # Identity check: ``!=`` on a tensor is not a reliable None test.
        if masks is None:
            return logits

        dice = DiceLoss(mode='binary')(logits, masks)
        bce = nn.BCEWithLogitsLoss()(logits, masks)
        return logits, dice + bce

In [74]:
# Build the segmentation network and move it to the selected device.
model = SegmentationModel()
model.to(device)

SegmentationModel(
  (model): UnetPlusPlus(
    (encoder): EfficientNetEncoder(
      (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNormAct2d(
        32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (drop): Identity()
        (act): Swish()
      )
      (blocks): Sequential(
        (0): Sequential(
          (0): DepthwiseSeparableConv(
            (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (bn1): BatchNormAct2d(
              32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
              (drop): Identity()
              (act): Swish()
            )
            (se): SqueezeExcite(
              (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (act1): Swish()
              (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (gate): Sigmoid()
            )
 

In [75]:
# Training configuration: Adamax optimiser with a binary Dice loss.
num_epochs = 40
learning_rate = 0.001
optimizer = optim.Adamax(model.parameters(), lr=learning_rate)
criterion = DiceLoss(mode='binary')

model, train_losses, val_losses = learning(num_epochs, train_loader, val_loader, model, optimizer, criterion)

# One stacked panel per loss history; the legend label doubles as the title.
loss_curves = (
    (train_losses, 'Training Losses', 'blue'),
    (val_losses, 'Validation Losses', 'orange'),
)
for panel_index, (curve, caption, line_color) in enumerate(loss_curves, start=1):
    plt.subplot(2, 1, panel_index)
    plt.plot(curve, label=caption, color=line_color)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(caption)
    plt.legend()
plt.tight_layout()
plt.show()

  2%|▏         | 1/64 [00:16<17:45, 16.91s/it]


KeyboardInterrupt: 

### Evaluation (оценка качества модели)

In [None]:
# Report the loss on the validation split and plot up to 10 samples.
validation(model, val_loader, 10)

In [None]:
# Report the loss on the test split and plot up to 10 samples.
validation(model, test_loader, 10)